[llvm] [RISCV] Support scalable vectors for the zvqdotq lowering paths (PR #140922)
Philip Reames via llvm-commits
llvm-commits at lists.llvm.org
Wed May 21 08:58:53 PDT 2025
https://github.com/preames created https://github.com/llvm/llvm-project/pull/140922
Scalable vector support for the zvqdotq lowering paths was an oversight in the
original patch series. Without this change, the newly added tests fail assertions.
>From 8e875e13afda56300ce19503139f5ba8cffbf8d4 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Wed, 21 May 2025 08:32:49 -0700
Subject: [PATCH 1/2] [RISCV] Support scalable vectors for the zvqdotq lowering
paths
Scalable vector support for the zvqdotq lowering paths was an oversight in the
original patch series. Without this change, the newly added tests fail assertions.
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 21 +-
llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll | 589 ++++++++++++++++++
2 files changed, 601 insertions(+), 9 deletions(-)
create mode 100644 llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 1158499718737..73798b899e9ff 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18177,17 +18177,20 @@ static SDValue lowerVQDOT(unsigned Opc, SDValue Op0, SDValue Op1,
assert(VT == Op1.getSimpleValueType() &&
VT.getVectorElementType() == MVT::i32);
- assert(VT.isFixedLengthVector());
- MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
- SDValue Passthru = convertToScalableVector(
- ContainerVT, DAG.getConstant(0, DL, VT), DAG, Subtarget);
- Op0 = convertToScalableVector(ContainerVT, Op0, DAG, Subtarget);
- Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);
-
+ SDValue Passthru = DAG.getConstant(0, DL, VT);
+ MVT ContainerVT = VT;
+ if (VT.isFixedLengthVector()) {
+ ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
+ Passthru = convertToScalableVector(ContainerVT, Passthru, DAG, Subtarget);
+ Op0 = convertToScalableVector(ContainerVT, Op0, DAG, Subtarget);
+ Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);
+ }
auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
SDValue LocalAccum = DAG.getNode(Opc, DL, ContainerVT,
{Op0, Op1, Passthru, Mask, VL});
- return convertFromScalableVector(VT, LocalAccum, DAG, Subtarget);
+ if (VT.isFixedLengthVector())
+ return convertFromScalableVector(VT, LocalAccum, DAG, Subtarget);
+ return LocalAccum;
}
static MVT getQDOTXResultType(MVT OpVT) {
@@ -18207,7 +18210,7 @@ static SDValue getZeroPaddedAdd(const SDLoc &DL, SDValue A, SDValue B,
EVT AVT = A.getValueType();
EVT BVT = B.getValueType();
assert(AVT.getVectorElementType() == BVT.getVectorElementType());
- if (AVT.getVectorNumElements() > BVT.getVectorNumElements()) {
+ if (AVT.getVectorMinNumElements() > BVT.getVectorMinNumElements()) {
std::swap(A, B);
std::swap(AVT, BVT);
}
diff --git a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
new file mode 100644
index 0000000000000..d811cdb5e444d
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
@@ -0,0 +1,589 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
+; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
+; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT32
+; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT64
+
+define i32 @vqdot_vv(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; NODOT-LABEL: vqdot_vv:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetvli a0, zero, e16, m4, ta, ma
+; NODOT-NEXT: vsext.vf2 v16, v8
+; NODOT-NEXT: vsext.vf2 v20, v10
+; NODOT-NEXT: vwmul.vv v8, v16, v20
+; NODOT-NEXT: vsetvli zero, zero, e32, m8, ta, ma
+; NODOT-NEXT: vmv.s.x v16, zero
+; NODOT-NEXT: vredsum.vs v8, v8, v16
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdot_vv:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT: vmv.v.i v12, 0
+; DOT-NEXT: vqdot.vv v12, v8, v10
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vredsum.vs v8, v12, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
+entry:
+ %a.sext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %b.sext = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+ %mul = mul nuw nsw <vscale x 16 x i32> %a.sext, %b.sext
+ %res = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %mul)
+ ret i32 %res
+}
+
+define i32 @vqdot_vx_constant(<vscale x 16 x i8> %a) {
+; CHECK-LABEL: vqdot_vx_constant:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetvli a0, zero, e16, m4, ta, ma
+; CHECK-NEXT: vsext.vf2 v16, v8
+; CHECK-NEXT: li a0, 23
+; CHECK-NEXT: vwmul.vx v8, v16, a0
+; CHECK-NEXT: vsetvli zero, zero, e32, m8, ta, ma
+; CHECK-NEXT: vmv.s.x v16, zero
+; CHECK-NEXT: vredsum.vs v8, v8, v16
+; CHECK-NEXT: vmv.x.s a0, v8
+; CHECK-NEXT: ret
+entry:
+ %a.sext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %mul = mul nuw nsw <vscale x 16 x i32> %a.sext, splat (i32 23)
+ %res = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %mul)
+ ret i32 %res
+}
+
+define i32 @vqdot_vx_constant_swapped(<vscale x 16 x i8> %a) {
+; CHECK-LABEL: vqdot_vx_constant_swapped:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetvli a0, zero, e16, m4, ta, ma
+; CHECK-NEXT: vsext.vf2 v16, v8
+; CHECK-NEXT: li a0, 23
+; CHECK-NEXT: vwmul.vx v8, v16, a0
+; CHECK-NEXT: vsetvli zero, zero, e32, m8, ta, ma
+; CHECK-NEXT: vmv.s.x v16, zero
+; CHECK-NEXT: vredsum.vs v8, v8, v16
+; CHECK-NEXT: vmv.x.s a0, v8
+; CHECK-NEXT: ret
+entry:
+ %a.sext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %mul = mul nuw nsw <vscale x 16 x i32> splat (i32 23), %a.sext
+ %res = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %mul)
+ ret i32 %res
+}
+
+define i32 @vqdotu_vv(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; NODOT-LABEL: vqdotu_vv:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetvli a0, zero, e8, m2, ta, ma
+; NODOT-NEXT: vwmulu.vv v12, v8, v10
+; NODOT-NEXT: vsetvli zero, zero, e32, m8, ta, ma
+; NODOT-NEXT: vmv.s.x v8, zero
+; NODOT-NEXT: vsetvli zero, zero, e16, m4, ta, ma
+; NODOT-NEXT: vwredsumu.vs v8, v12, v8
+; NODOT-NEXT: vsetvli zero, zero, e32, m8, ta, ma
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdotu_vv:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT: vmv.v.i v12, 0
+; DOT-NEXT: vqdotu.vv v12, v8, v10
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vredsum.vs v8, v12, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
+entry:
+ %a.zext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %b.zext = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+ %mul = mul nuw nsw <vscale x 16 x i32> %a.zext, %b.zext
+ %res = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %mul)
+ ret i32 %res
+}
+
+define i32 @vqdotu_vx_constant(<vscale x 16 x i8> %a) {
+; CHECK-LABEL: vqdotu_vx_constant:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetvli a0, zero, e16, m4, ta, ma
+; CHECK-NEXT: vzext.vf2 v16, v8
+; CHECK-NEXT: li a0, 123
+; CHECK-NEXT: vwmulu.vx v8, v16, a0
+; CHECK-NEXT: vsetvli zero, zero, e32, m8, ta, ma
+; CHECK-NEXT: vmv.s.x v16, zero
+; CHECK-NEXT: vredsum.vs v8, v8, v16
+; CHECK-NEXT: vmv.x.s a0, v8
+; CHECK-NEXT: ret
+entry:
+ %a.zext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %mul = mul nuw nsw <vscale x 16 x i32> %a.zext, splat (i32 123)
+ %res = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %mul)
+ ret i32 %res
+}
+
+define i32 @vqdotsu_vv(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; NODOT-LABEL: vqdotsu_vv:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetvli a0, zero, e16, m4, ta, ma
+; NODOT-NEXT: vsext.vf2 v16, v8
+; NODOT-NEXT: vzext.vf2 v20, v10
+; NODOT-NEXT: vwmulsu.vv v8, v16, v20
+; NODOT-NEXT: vsetvli zero, zero, e32, m8, ta, ma
+; NODOT-NEXT: vmv.s.x v16, zero
+; NODOT-NEXT: vredsum.vs v8, v8, v16
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdotsu_vv:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT: vmv.v.i v12, 0
+; DOT-NEXT: vqdotsu.vv v12, v8, v10
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vredsum.vs v8, v12, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
+entry:
+ %a.sext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %b.zext = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+ %mul = mul nuw nsw <vscale x 16 x i32> %a.sext, %b.zext
+ %res = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %mul)
+ ret i32 %res
+}
+
+define i32 @vqdotsu_vv_swapped(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; NODOT-LABEL: vqdotsu_vv_swapped:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetvli a0, zero, e16, m4, ta, ma
+; NODOT-NEXT: vsext.vf2 v16, v8
+; NODOT-NEXT: vzext.vf2 v20, v10
+; NODOT-NEXT: vwmulsu.vv v8, v16, v20
+; NODOT-NEXT: vsetvli zero, zero, e32, m8, ta, ma
+; NODOT-NEXT: vmv.s.x v16, zero
+; NODOT-NEXT: vredsum.vs v8, v8, v16
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdotsu_vv_swapped:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT: vmv.v.i v12, 0
+; DOT-NEXT: vqdotsu.vv v12, v8, v10
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vredsum.vs v8, v12, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
+entry:
+ %a.sext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %b.zext = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+ %mul = mul nuw nsw <vscale x 16 x i32> %b.zext, %a.sext
+ %res = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %mul)
+ ret i32 %res
+}
+
+define i32 @vdotqsu_vx_constant(<vscale x 16 x i8> %a) {
+; CHECK-LABEL: vdotqsu_vx_constant:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetvli a0, zero, e16, m4, ta, ma
+; CHECK-NEXT: vsext.vf2 v16, v8
+; CHECK-NEXT: li a0, 123
+; CHECK-NEXT: vwmul.vx v8, v16, a0
+; CHECK-NEXT: vsetvli zero, zero, e32, m8, ta, ma
+; CHECK-NEXT: vmv.s.x v16, zero
+; CHECK-NEXT: vredsum.vs v8, v8, v16
+; CHECK-NEXT: vmv.x.s a0, v8
+; CHECK-NEXT: ret
+entry:
+ %a.sext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %mul = mul nuw nsw <vscale x 16 x i32> %a.sext, splat (i32 123)
+ %res = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %mul)
+ ret i32 %res
+}
+
+define i32 @vdotqus_vx_constant(<vscale x 16 x i8> %a) {
+; CHECK-LABEL: vdotqus_vx_constant:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetvli a0, zero, e16, m4, ta, ma
+; CHECK-NEXT: vzext.vf2 v16, v8
+; CHECK-NEXT: li a0, -23
+; CHECK-NEXT: vmv.v.x v20, a0
+; CHECK-NEXT: vwmulsu.vv v8, v20, v16
+; CHECK-NEXT: vsetvli zero, zero, e32, m8, ta, ma
+; CHECK-NEXT: vmv.s.x v16, zero
+; CHECK-NEXT: vredsum.vs v8, v8, v16
+; CHECK-NEXT: vmv.x.s a0, v8
+; CHECK-NEXT: ret
+entry:
+ %a.zext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %mul = mul nuw nsw <vscale x 16 x i32> %a.zext, splat (i32 -23)
+ %res = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %mul)
+ ret i32 %res
+}
+
+define i32 @reduce_of_sext(<vscale x 16 x i8> %a) {
+; NODOT-LABEL: reduce_of_sext:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetvli a0, zero, e32, m8, ta, ma
+; NODOT-NEXT: vsext.vf4 v16, v8
+; NODOT-NEXT: vmv.s.x v8, zero
+; NODOT-NEXT: vredsum.vs v8, v16, v8
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT32-LABEL: reduce_of_sext:
+; DOT32: # %bb.0: # %entry
+; DOT32-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; DOT32-NEXT: vmv.v.i v10, 0
+; DOT32-NEXT: lui a0, 4112
+; DOT32-NEXT: addi a0, a0, 257
+; DOT32-NEXT: vqdot.vx v10, v8, a0
+; DOT32-NEXT: vmv.s.x v8, zero
+; DOT32-NEXT: vredsum.vs v8, v10, v8
+; DOT32-NEXT: vmv.x.s a0, v8
+; DOT32-NEXT: ret
+;
+; DOT64-LABEL: reduce_of_sext:
+; DOT64: # %bb.0: # %entry
+; DOT64-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; DOT64-NEXT: vmv.v.i v10, 0
+; DOT64-NEXT: lui a0, 4112
+; DOT64-NEXT: addiw a0, a0, 257
+; DOT64-NEXT: vqdot.vx v10, v8, a0
+; DOT64-NEXT: vmv.s.x v8, zero
+; DOT64-NEXT: vredsum.vs v8, v10, v8
+; DOT64-NEXT: vmv.x.s a0, v8
+; DOT64-NEXT: ret
+entry:
+ %a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %res = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %a.ext)
+ ret i32 %res
+}
+
+define i32 @reduce_of_zext(<vscale x 16 x i8> %a) {
+; NODOT-LABEL: reduce_of_zext:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetvli a0, zero, e32, m8, ta, ma
+; NODOT-NEXT: vzext.vf4 v16, v8
+; NODOT-NEXT: vmv.s.x v8, zero
+; NODOT-NEXT: vredsum.vs v8, v16, v8
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT32-LABEL: reduce_of_zext:
+; DOT32: # %bb.0: # %entry
+; DOT32-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; DOT32-NEXT: vmv.v.i v10, 0
+; DOT32-NEXT: lui a0, 4112
+; DOT32-NEXT: addi a0, a0, 257
+; DOT32-NEXT: vqdotu.vx v10, v8, a0
+; DOT32-NEXT: vmv.s.x v8, zero
+; DOT32-NEXT: vredsum.vs v8, v10, v8
+; DOT32-NEXT: vmv.x.s a0, v8
+; DOT32-NEXT: ret
+;
+; DOT64-LABEL: reduce_of_zext:
+; DOT64: # %bb.0: # %entry
+; DOT64-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; DOT64-NEXT: vmv.v.i v10, 0
+; DOT64-NEXT: lui a0, 4112
+; DOT64-NEXT: addiw a0, a0, 257
+; DOT64-NEXT: vqdotu.vx v10, v8, a0
+; DOT64-NEXT: vmv.s.x v8, zero
+; DOT64-NEXT: vredsum.vs v8, v10, v8
+; DOT64-NEXT: vmv.x.s a0, v8
+; DOT64-NEXT: ret
+entry:
+ %a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %res = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %a.ext)
+ ret i32 %res
+}
+
+define i32 @vqdot_vv_accum(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 16 x i32> %x) {
+; NODOT-LABEL: vqdot_vv_accum:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetvli a0, zero, e16, m4, ta, ma
+; NODOT-NEXT: vsext.vf2 v12, v8
+; NODOT-NEXT: vsext.vf2 v24, v10
+; NODOT-NEXT: vwmacc.vv v16, v12, v24
+; NODOT-NEXT: vsetvli zero, zero, e32, m8, ta, ma
+; NODOT-NEXT: vmv.s.x v8, zero
+; NODOT-NEXT: vredsum.vs v8, v16, v8
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdot_vv_accum:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT: vmv.v.i v12, 0
+; DOT-NEXT: vqdot.vv v12, v8, v10
+; DOT-NEXT: vadd.vv v16, v12, v16
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vsetvli a0, zero, e32, m8, ta, ma
+; DOT-NEXT: vredsum.vs v8, v16, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
+entry:
+ %a.sext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %b.sext = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+ %mul = mul nuw nsw <vscale x 16 x i32> %a.sext, %b.sext
+ %add = add <vscale x 16 x i32> %mul, %x
+ %sum = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %add)
+ ret i32 %sum
+}
+
+define i32 @vqdotu_vv_accum(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 16 x i32> %x) {
+; NODOT-LABEL: vqdotu_vv_accum:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetvli a0, zero, e8, m2, ta, ma
+; NODOT-NEXT: vwmulu.vv v12, v8, v10
+; NODOT-NEXT: vsetvli zero, zero, e16, m4, ta, ma
+; NODOT-NEXT: vwaddu.wv v16, v16, v12
+; NODOT-NEXT: vsetvli zero, zero, e32, m8, ta, ma
+; NODOT-NEXT: vmv.s.x v8, zero
+; NODOT-NEXT: vredsum.vs v8, v16, v8
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdotu_vv_accum:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT: vmv.v.i v12, 0
+; DOT-NEXT: vqdotu.vv v12, v8, v10
+; DOT-NEXT: vadd.vv v16, v12, v16
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vsetvli a0, zero, e32, m8, ta, ma
+; DOT-NEXT: vredsum.vs v8, v16, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
+entry:
+ %a.zext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %b.zext = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+ %mul = mul nuw nsw <vscale x 16 x i32> %a.zext, %b.zext
+ %add = add <vscale x 16 x i32> %mul, %x
+ %sum = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %add)
+ ret i32 %sum
+}
+
+define i32 @vqdotsu_vv_accum(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 16 x i32> %x) {
+; NODOT-LABEL: vqdotsu_vv_accum:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetvli a0, zero, e16, m4, ta, ma
+; NODOT-NEXT: vsext.vf2 v12, v8
+; NODOT-NEXT: vzext.vf2 v24, v10
+; NODOT-NEXT: vwmaccsu.vv v16, v12, v24
+; NODOT-NEXT: vsetvli zero, zero, e32, m8, ta, ma
+; NODOT-NEXT: vmv.s.x v8, zero
+; NODOT-NEXT: vredsum.vs v8, v16, v8
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdotsu_vv_accum:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT: vmv.v.i v12, 0
+; DOT-NEXT: vqdotsu.vv v12, v8, v10
+; DOT-NEXT: vadd.vv v16, v12, v16
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vsetvli a0, zero, e32, m8, ta, ma
+; DOT-NEXT: vredsum.vs v8, v16, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
+entry:
+ %a.sext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %b.zext = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+ %mul = mul nuw nsw <vscale x 16 x i32> %a.sext, %b.zext
+ %add = add <vscale x 16 x i32> %mul, %x
+ %sum = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %add)
+ ret i32 %sum
+}
+
+define i32 @vqdot_vv_scalar_add(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, i32 %x) {
+; NODOT-LABEL: vqdot_vv_scalar_add:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetvli a1, zero, e16, m4, ta, ma
+; NODOT-NEXT: vsext.vf2 v16, v8
+; NODOT-NEXT: vsext.vf2 v20, v10
+; NODOT-NEXT: vwmul.vv v8, v16, v20
+; NODOT-NEXT: vsetvli zero, zero, e32, m8, ta, ma
+; NODOT-NEXT: vmv.s.x v16, a0
+; NODOT-NEXT: vredsum.vs v8, v8, v16
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdot_vv_scalar_add:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetvli a1, zero, e32, m2, ta, ma
+; DOT-NEXT: vmv.v.i v12, 0
+; DOT-NEXT: vqdot.vv v12, v8, v10
+; DOT-NEXT: vmv.s.x v8, a0
+; DOT-NEXT: vredsum.vs v8, v12, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
+entry:
+ %a.sext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %b.sext = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+ %mul = mul nuw nsw <vscale x 16 x i32> %a.sext, %b.sext
+ %sum = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %mul)
+ %add = add i32 %sum, %x
+ ret i32 %add
+}
+
+define i32 @vqdotu_vv_scalar_add(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, i32 %x) {
+; NODOT-LABEL: vqdotu_vv_scalar_add:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetvli a1, zero, e8, m2, ta, ma
+; NODOT-NEXT: vwmulu.vv v12, v8, v10
+; NODOT-NEXT: vsetvli zero, zero, e32, m8, ta, ma
+; NODOT-NEXT: vmv.s.x v8, a0
+; NODOT-NEXT: vsetvli zero, zero, e16, m4, ta, ma
+; NODOT-NEXT: vwredsumu.vs v8, v12, v8
+; NODOT-NEXT: vsetvli zero, zero, e32, m8, ta, ma
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdotu_vv_scalar_add:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetvli a1, zero, e32, m2, ta, ma
+; DOT-NEXT: vmv.v.i v12, 0
+; DOT-NEXT: vqdotu.vv v12, v8, v10
+; DOT-NEXT: vmv.s.x v8, a0
+; DOT-NEXT: vredsum.vs v8, v12, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
+entry:
+ %a.zext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %b.zext = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+ %mul = mul nuw nsw <vscale x 16 x i32> %a.zext, %b.zext
+ %sum = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %mul)
+ %add = add i32 %sum, %x
+ ret i32 %add
+}
+
+define i32 @vqdotsu_vv_scalar_add(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, i32 %x) {
+; NODOT-LABEL: vqdotsu_vv_scalar_add:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetvli a1, zero, e16, m4, ta, ma
+; NODOT-NEXT: vsext.vf2 v16, v8
+; NODOT-NEXT: vzext.vf2 v20, v10
+; NODOT-NEXT: vwmulsu.vv v8, v16, v20
+; NODOT-NEXT: vsetvli zero, zero, e32, m8, ta, ma
+; NODOT-NEXT: vmv.s.x v16, a0
+; NODOT-NEXT: vredsum.vs v8, v8, v16
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdotsu_vv_scalar_add:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetvli a1, zero, e32, m2, ta, ma
+; DOT-NEXT: vmv.v.i v12, 0
+; DOT-NEXT: vqdotsu.vv v12, v8, v10
+; DOT-NEXT: vmv.s.x v8, a0
+; DOT-NEXT: vredsum.vs v8, v12, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
+entry:
+ %a.sext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %b.zext = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+ %mul = mul nuw nsw <vscale x 16 x i32> %a.sext, %b.zext
+ %sum = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %mul)
+ %add = add i32 %sum, %x
+ ret i32 %add
+}
+
+define i32 @vqdot_vv_split(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c, <vscale x 16 x i8> %d) {
+; NODOT-LABEL: vqdot_vv_split:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetvli a0, zero, e16, m4, ta, ma
+; NODOT-NEXT: vsext.vf2 v16, v8
+; NODOT-NEXT: vsext.vf2 v20, v10
+; NODOT-NEXT: vsext.vf2 v24, v12
+; NODOT-NEXT: vsext.vf2 v28, v14
+; NODOT-NEXT: vwmul.vv v8, v16, v20
+; NODOT-NEXT: vwmacc.vv v8, v24, v28
+; NODOT-NEXT: vsetvli zero, zero, e32, m8, ta, ma
+; NODOT-NEXT: vmv.s.x v16, zero
+; NODOT-NEXT: vredsum.vs v8, v8, v16
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdot_vv_split:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT: vmv.v.i v16, 0
+; DOT-NEXT: vmv.v.i v18, 0
+; DOT-NEXT: vqdot.vv v16, v8, v10
+; DOT-NEXT: vqdot.vv v18, v12, v14
+; DOT-NEXT: vadd.vv v8, v16, v18
+; DOT-NEXT: vmv.s.x v10, zero
+; DOT-NEXT: vredsum.vs v8, v8, v10
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
+entry:
+ %a.sext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %b.sext = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+ %mul = mul nuw nsw <vscale x 16 x i32> %a.sext, %b.sext
+ %c.sext = sext <vscale x 16 x i8> %c to <vscale x 16 x i32>
+ %d.sext = sext <vscale x 16 x i8> %d to <vscale x 16 x i32>
+ %mul2 = mul nuw nsw <vscale x 16 x i32> %c.sext, %d.sext
+ %add = add <vscale x 16 x i32> %mul, %mul2
+ %sum = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %add)
+ ret i32 %sum
+}
+
+
+define <vscale x 4 x i32> @vqdot_vv_partial_reduce(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-LABEL: vqdot_vv_partial_reduce:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetvli a0, zero, e16, m4, ta, ma
+; CHECK-NEXT: vsext.vf2 v16, v8
+; CHECK-NEXT: vsext.vf2 v20, v10
+; CHECK-NEXT: vwmul.vv v8, v16, v20
+; CHECK-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; CHECK-NEXT: vadd.vv v8, v14, v8
+; CHECK-NEXT: vadd.vv v10, v10, v12
+; CHECK-NEXT: vadd.vv v8, v10, v8
+; CHECK-NEXT: ret
+entry:
+ %a.sext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %b.sext = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+ %mul = mul nuw nsw <vscale x 16 x i32> %a.sext, %b.sext
+ %res = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i32> %mul)
+ ret <vscale x 4 x i32> %res
+}
+
+define <vscale x 4 x i32> @vqdot_vv_partial_reduce2(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 4 x i32> %accum) {
+; CHECK-LABEL: vqdot_vv_partial_reduce2:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetvli a0, zero, e16, m4, ta, ma
+; CHECK-NEXT: vsext.vf2 v24, v8
+; CHECK-NEXT: vsext.vf2 v28, v10
+; CHECK-NEXT: vwmul.vv v16, v24, v28
+; CHECK-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; CHECK-NEXT: vadd.vv v8, v18, v20
+; CHECK-NEXT: vadd.vv v10, v12, v16
+; CHECK-NEXT: vadd.vv v10, v22, v10
+; CHECK-NEXT: vadd.vv v8, v8, v10
+; CHECK-NEXT: ret
+entry:
+ %a.sext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %b.sext = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+ %mul = mul nuw nsw <vscale x 16 x i32> %a.sext, %b.sext
+ %res = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 4 x i32> %accum, <vscale x 16 x i32> %mul)
+ ret <vscale x 4 x i32> %res
+}
+
+define <vscale x 16 x i32> @vqdot_vv_partial_reduce3(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-LABEL: vqdot_vv_partial_reduce3:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetvli a0, zero, e16, m4, ta, ma
+; CHECK-NEXT: vsext.vf2 v16, v8
+; CHECK-NEXT: vsext.vf2 v20, v10
+; CHECK-NEXT: vwmul.vv v8, v16, v20
+; CHECK-NEXT: ret
+entry:
+ %a.sext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %b.sext = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+ %mul = mul nuw nsw <vscale x 16 x i32> %a.sext, %b.sext
+ %res = call <vscale x 16 x i32> @llvm.experimental.vector.partial.reduce.add.nvx8i32.nvx16i32.nvx16i32(<vscale x 16 x i32> %mul, <vscale x 16 x i32> zeroinitializer)
+ ret <vscale x 16 x i32> %res
+}
>From aa122051a8405bf00922309004c76e49c79bd3f8 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Wed, 21 May 2025 08:46:37 -0700
Subject: [PATCH 2/2] [RISCV] Support scalable vectors in zvqdotq accumulator
folding
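(This part is a missed optimization, not a correctness issue: for scalable
vectors the accumulating add is a plain ISD::ADD rather than RISCVISD::ADD_VL,
so combineVqdotAccum now accepts both forms.)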
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 23 +++++++++----
llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll | 32 +++++++------------
2 files changed, 29 insertions(+), 26 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 73798b899e9ff..d69e04a9912a2 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18644,7 +18644,7 @@ static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
static SDValue combineVqdotAccum(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
- assert(N->getOpcode() == RISCVISD::ADD_VL);
+ assert(N->getOpcode() == RISCVISD::ADD_VL || N->getOpcode() == ISD::ADD);
if (!N->getValueType(0).isVector())
return SDValue();
@@ -18652,9 +18652,11 @@ static SDValue combineVqdotAccum(SDNode *N, SelectionDAG &DAG,
SDValue Addend = N->getOperand(0);
SDValue DotOp = N->getOperand(1);
- SDValue AddPassthruOp = N->getOperand(2);
- if (!AddPassthruOp.isUndef())
- return SDValue();
+ if (N->getOpcode() == RISCVISD::ADD_VL) {
+ SDValue AddPassthruOp = N->getOperand(2);
+ if (!AddPassthruOp.isUndef())
+ return SDValue();
+ }
auto IsVqdotqOpc = [](unsigned Opc) {
switch (Opc) {
@@ -18673,8 +18675,15 @@ static SDValue combineVqdotAccum(SDNode *N, SelectionDAG &DAG,
if (!IsVqdotqOpc(DotOp.getOpcode()))
return SDValue();
- SDValue AddMask = N->getOperand(3);
- SDValue AddVL = N->getOperand(4);
+ auto [AddMask, AddVL] = [](SDNode *N, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
+ if (N->getOpcode() == ISD::ADD) {
+ SDLoc DL(N);
+ return getDefaultScalableVLOps(N->getSimpleValueType(0), DL, DAG,
+ Subtarget);
+ }
+ return std::make_pair(N->getOperand(3), N->getOperand(4));
+ }(N, DAG, Subtarget);
SDValue MulVL = DotOp.getOperand(4);
if (AddVL != MulVL)
@@ -19312,6 +19321,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
return V;
if (SDValue V = combineToVWMACC(N, DAG, Subtarget))
return V;
+ if (SDValue V = combineVqdotAccum(N, DAG, Subtarget))
+ return V;
return performADDCombine(N, DCI, Subtarget);
}
case ISD::SUB: {
diff --git a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
index d811cdb5e444d..a56ef0cd75d6a 100644
--- a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
@@ -314,12 +314,10 @@ define i32 @vqdot_vv_accum(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale
; DOT-LABEL: vqdot_vv_accum:
; DOT: # %bb.0: # %entry
; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
-; DOT-NEXT: vmv.v.i v12, 0
-; DOT-NEXT: vqdot.vv v12, v8, v10
-; DOT-NEXT: vadd.vv v16, v12, v16
-; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vmv.s.x v12, zero
+; DOT-NEXT: vqdot.vv v16, v8, v10
; DOT-NEXT: vsetvli a0, zero, e32, m8, ta, ma
-; DOT-NEXT: vredsum.vs v8, v16, v8
+; DOT-NEXT: vredsum.vs v8, v16, v12
; DOT-NEXT: vmv.x.s a0, v8
; DOT-NEXT: ret
entry:
@@ -347,12 +345,10 @@ define i32 @vqdotu_vv_accum(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscal
; DOT-LABEL: vqdotu_vv_accum:
; DOT: # %bb.0: # %entry
; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
-; DOT-NEXT: vmv.v.i v12, 0
-; DOT-NEXT: vqdotu.vv v12, v8, v10
-; DOT-NEXT: vadd.vv v16, v12, v16
-; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vmv.s.x v12, zero
+; DOT-NEXT: vqdotu.vv v16, v8, v10
; DOT-NEXT: vsetvli a0, zero, e32, m8, ta, ma
-; DOT-NEXT: vredsum.vs v8, v16, v8
+; DOT-NEXT: vredsum.vs v8, v16, v12
; DOT-NEXT: vmv.x.s a0, v8
; DOT-NEXT: ret
entry:
@@ -380,12 +376,10 @@ define i32 @vqdotsu_vv_accum(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vsca
; DOT-LABEL: vqdotsu_vv_accum:
; DOT: # %bb.0: # %entry
; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
-; DOT-NEXT: vmv.v.i v12, 0
-; DOT-NEXT: vqdotsu.vv v12, v8, v10
-; DOT-NEXT: vadd.vv v16, v12, v16
-; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vmv.s.x v12, zero
+; DOT-NEXT: vqdotsu.vv v16, v8, v10
; DOT-NEXT: vsetvli a0, zero, e32, m8, ta, ma
-; DOT-NEXT: vredsum.vs v8, v16, v8
+; DOT-NEXT: vredsum.vs v8, v16, v12
; DOT-NEXT: vmv.x.s a0, v8
; DOT-NEXT: ret
entry:
@@ -510,12 +504,10 @@ define i32 @vqdot_vv_split(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale
; DOT: # %bb.0: # %entry
; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
; DOT-NEXT: vmv.v.i v16, 0
-; DOT-NEXT: vmv.v.i v18, 0
; DOT-NEXT: vqdot.vv v16, v8, v10
-; DOT-NEXT: vqdot.vv v18, v12, v14
-; DOT-NEXT: vadd.vv v8, v16, v18
-; DOT-NEXT: vmv.s.x v10, zero
-; DOT-NEXT: vredsum.vs v8, v8, v10
+; DOT-NEXT: vqdot.vv v16, v12, v14
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vredsum.vs v8, v16, v8
; DOT-NEXT: vmv.x.s a0, v8
; DOT-NEXT: ret
entry:
More information about the llvm-commits
mailing list