[llvm] 443cdd0 - [RISCV] Fix a bug in partial.reduce lowering for zvqdotq .vx forms (#142185)
via llvm-commits
llvm-commits at lists.llvm.org
Fri May 30 11:05:46 PDT 2025
Author: Philip Reames
Date: 2025-05-30T11:05:43-07:00
New Revision: 443cdd0b48b850b9f2d7cb93f9cc0ba8d5ac4827
URL: https://github.com/llvm/llvm-project/commit/443cdd0b48b850b9f2d7cb93f9cc0ba8d5ac4827
DIFF: https://github.com/llvm/llvm-project/commit/443cdd0b48b850b9f2d7cb93f9cc0ba8d5ac4827.diff
LOG: [RISCV] Fix a bug in partial.reduce lowering for zvqdotq .vx forms (#142185)
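I'd missed a bitcast in the lowering. That bitcast is semantically required:
the partial_reduce_* source operands have an i8 element type, but the pseudos
and patterns expect an i32 element type. Without it, a splat of an i8 value
was matched directly to a .vx form. That form treats the scalar as one i32
holding four packed i8 lanes. So a splat of 128, for example, became the bytes
(128, 0, 0, 0) instead of 128 in every i8 lane.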
From the cases I've found so far, this appears to affect only the .vx
matching. LV does not yet generate anything that exercises it. The reduce
path used by SLP (as opposed to the partial.reduce path) currently constructs
the i32 value manually and goes directly to the pseudos with their i32
arguments, bypassing the partial_reduce nodes.
We're basically losing the .vx matching on this path until we teach splat
matching to replicate the i8 value into all four bytes of an i32 scalar
(materialized via LUI/ADDI).
Added:
Modified:
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f2311e94252e9..b7fd0c93fa93f 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -8412,13 +8412,18 @@ SDValue RISCVTargetLowering::lowerPARTIAL_REDUCE_MLA(SDValue Op,
assert(ArgVT == B.getSimpleValueType() &&
ArgVT.getVectorElementType() == MVT::i8);
+ // The zvqdotq pseudos are defined with sources and destination both
+ // being i32. This cast is needed for correctness to avoid incorrect
+ // .vx matching of i8 splats.
+ A = DAG.getBitcast(VT, A);
+ B = DAG.getBitcast(VT, B);
+
MVT ContainerVT = VT;
if (VT.isFixedLengthVector()) {
ContainerVT = getContainerForFixedLengthVector(VT);
Accum = convertToScalableVector(ContainerVT, Accum, DAG, Subtarget);
- MVT ArgContainerVT = getContainerForFixedLengthVector(ArgVT);
- A = convertToScalableVector(ArgContainerVT, A, DAG, Subtarget);
- B = convertToScalableVector(ArgContainerVT, B, DAG, Subtarget);
+ A = convertToScalableVector(ContainerVT, A, DAG, Subtarget);
+ B = convertToScalableVector(ContainerVT, B, DAG, Subtarget);
}
bool IsSigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
index 5e9bbe6c1ebce..0237faea9efb7 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
@@ -598,7 +598,6 @@ entry:
ret <1 x i32> %res
}
-; FIXME: This case is wrong. We should be splatting 128 to each i8 lane!
define <1 x i32> @vqdotu_vx_partial_reduce(<4 x i8> %a, <4 x i8> %b) {
; NODOT-LABEL: vqdotu_vx_partial_reduce:
; NODOT: # %bb.0: # %entry
@@ -618,10 +617,13 @@ define <1 x i32> @vqdotu_vx_partial_reduce(<4 x i8> %a, <4 x i8> %b) {
;
; DOT-LABEL: vqdotu_vx_partial_reduce:
; DOT: # %bb.0: # %entry
-; DOT-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
+; DOT-NEXT: vsetivli zero, 1, e32, m1, ta, ma
; DOT-NEXT: vmv.s.x v9, zero
; DOT-NEXT: li a0, 128
-; DOT-NEXT: vqdotu.vx v9, v8, a0
+; DOT-NEXT: vsetivli zero, 4, e8, mf4, ta, ma
+; DOT-NEXT: vmv.v.x v10, a0
+; DOT-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
+; DOT-NEXT: vqdotu.vv v9, v8, v10
; DOT-NEXT: vmv1r.v v8, v9
; DOT-NEXT: ret
entry:
@@ -631,7 +633,6 @@ entry:
ret <1 x i32> %res
}
-; FIXME: This case is wrong. We should be splatting 128 to each i8 lane!
define <1 x i32> @vqdot_vx_partial_reduce(<4 x i8> %a, <4 x i8> %b) {
; NODOT-LABEL: vqdot_vx_partial_reduce:
; NODOT: # %bb.0: # %entry
@@ -652,10 +653,13 @@ define <1 x i32> @vqdot_vx_partial_reduce(<4 x i8> %a, <4 x i8> %b) {
;
; DOT-LABEL: vqdot_vx_partial_reduce:
; DOT: # %bb.0: # %entry
-; DOT-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
+; DOT-NEXT: vsetivli zero, 1, e32, m1, ta, ma
; DOT-NEXT: vmv.s.x v9, zero
; DOT-NEXT: li a0, 128
-; DOT-NEXT: vqdot.vx v9, v8, a0
+; DOT-NEXT: vsetivli zero, 4, e8, mf4, ta, ma
+; DOT-NEXT: vmv.v.x v10, a0
+; DOT-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
+; DOT-NEXT: vqdot.vv v9, v8, v10
; DOT-NEXT: vmv1r.v v8, v9
; DOT-NEXT: ret
entry:
@@ -1372,7 +1376,6 @@ entry:
}
-; FIXME: This case is wrong. We should be splatting 128 to each i8 lane!
define <4 x i32> @partial_of_sext(<16 x i8> %a) {
; NODOT-LABEL: partial_of_sext:
; NODOT: # %bb.0: # %entry
@@ -1393,10 +1396,11 @@ define <4 x i32> @partial_of_sext(<16 x i8> %a) {
;
; DOT-LABEL: partial_of_sext:
; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 16, e8, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 1
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
; DOT-NEXT: vmv.v.i v9, 0
-; DOT-NEXT: li a0, 1
-; DOT-NEXT: vqdot.vx v9, v8, a0
+; DOT-NEXT: vqdot.vv v9, v8, v10
; DOT-NEXT: vmv.v.v v8, v9
; DOT-NEXT: ret
entry:
@@ -1405,7 +1409,6 @@ entry:
ret <4 x i32> %res
}
-; FIXME: This case is wrong. We should be splatting 128 to each i8 lane!
define <4 x i32> @partial_of_zext(<16 x i8> %a) {
; NODOT-LABEL: partial_of_zext:
; NODOT: # %bb.0: # %entry
@@ -1426,10 +1429,11 @@ define <4 x i32> @partial_of_zext(<16 x i8> %a) {
;
; DOT-LABEL: partial_of_zext:
; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 16, e8, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 1
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
; DOT-NEXT: vmv.v.i v9, 0
-; DOT-NEXT: li a0, 1
-; DOT-NEXT: vqdotu.vx v9, v8, a0
+; DOT-NEXT: vqdotu.vv v9, v8, v10
; DOT-NEXT: vmv.v.v v8, v9
; DOT-NEXT: ret
entry:
diff --git a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
index 2bd2ef2878fd5..d0fc915a0d07e 100644
--- a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
@@ -957,3 +957,56 @@ entry:
%res = call <vscale x 1 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 1 x i32> zeroinitializer, <vscale x 4 x i32> %mul)
ret <vscale x 1 x i32> %res
}
+
+
+define <vscale x 4 x i32> @partial_of_sext(<vscale x 16 x i8> %a) {
+; NODOT-LABEL: partial_of_sext:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetvli a0, zero, e32, m8, ta, ma
+; NODOT-NEXT: vsext.vf4 v16, v8
+; NODOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; NODOT-NEXT: vadd.vv v8, v22, v16
+; NODOT-NEXT: vadd.vv v10, v18, v20
+; NODOT-NEXT: vadd.vv v8, v10, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: partial_of_sext:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetvli a0, zero, e8, m2, ta, ma
+; DOT-NEXT: vmv.v.i v12, 1
+; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdot.vv v10, v8, v12
+; DOT-NEXT: vmv.v.v v8, v10
+; DOT-NEXT: ret
+entry:
+ %a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %res = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i32> %a.ext)
+ ret <vscale x 4 x i32> %res
+}
+
+define <vscale x 4 x i32> @partial_of_zext(<vscale x 16 x i8> %a) {
+; NODOT-LABEL: partial_of_zext:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetvli a0, zero, e32, m8, ta, ma
+; NODOT-NEXT: vzext.vf4 v16, v8
+; NODOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; NODOT-NEXT: vadd.vv v8, v22, v16
+; NODOT-NEXT: vadd.vv v10, v18, v20
+; NODOT-NEXT: vadd.vv v8, v10, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: partial_of_zext:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetvli a0, zero, e8, m2, ta, ma
+; DOT-NEXT: vmv.v.i v12, 1
+; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdotu.vv v10, v8, v12
+; DOT-NEXT: vmv.v.v v8, v10
+; DOT-NEXT: ret
+entry:
+ %a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %res = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i32> %a.ext)
+ ret <vscale x 4 x i32> %res
+}
More information about the llvm-commits
mailing list