[llvm] [RISCV] Extend zvqdot matching to handle reduction trees (PR #138965)
via llvm-commits
llvm-commits at lists.llvm.org
Wed May 7 13:54:56 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-risc-v
Author: Philip Reames (preames)
<details>
<summary>Changes</summary>
Now that we have matching for vqdot in its basic variants, we can extend the matcher to handle reduction trees instead of individual reductions. This is important because we canonicalize reductions by performing a tree of adds in the vector domain before the root reduction instruction.
The particular approach taken here has the unfortunate implication that non-matches visit the entire reduction tree once for each time the reduction root is visited in the DAG. While conceptually problematic for compile time, this is probably fine in practice, as we should only visit the root once per pass of DAGCombine. I don't really see a better solution; suggestions welcome.
---
Full diff: https://github.com/llvm/llvm-project/pull/138965.diff
2 Files Affected:
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+43)
- (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll (+102-47)
``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 756c563f0194d..a4d778c054ecf 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18131,6 +18131,29 @@ static MVT getQDOTXResultType(MVT OpVT) {
return MVT::getVectorVT(MVT::i32, OpEC.divideCoefficientBy(4));
}
+/// Given fixed length vectors A and B with equal element types, but possibly
+/// different number of elements, return A + B where either A or B is zero
+/// padded to the larger number of elements.
+static SDValue getZeroPaddedAdd(const SDLoc &DL, SDValue A, SDValue B,
+ SelectionDAG &DAG) {
+ // NOTE: Manually doing the extract/add/insert scheme produces
+ // significantly better coegen than the naive pad with zeros
+ // and add scheme.
+ EVT AVT = A.getValueType();
+ EVT BVT = B.getValueType();
+ assert(AVT.getVectorElementType() == BVT.getVectorElementType());
+ if (AVT.getVectorNumElements() > BVT.getVectorNumElements()) {
+ std::swap(A, B);
+ std::swap(AVT, BVT);
+ }
+
+ SDValue BPart = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, AVT, B,
+ DAG.getVectorIdxConstant(0, DL));
+ SDValue Res = DAG.getNode(ISD::ADD, DL, AVT, A, BPart);
+ return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, BVT, B, Res,
+ DAG.getVectorIdxConstant(0, DL));
+}
+
static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
SelectionDAG &DAG,
const RISCVSubtarget &Subtarget,
@@ -18142,6 +18165,26 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
!InVec.getValueType().getVectorElementCount().isKnownMultipleOf(4))
return SDValue();
+ // Recurse through adds (since generic dag canonicalizes to that
+ // form).
+ if (InVec->getOpcode() == ISD::ADD) {
+ SDValue A = InVec.getOperand(0);
+ SDValue B = InVec.getOperand(1);
+ SDValue AOpt = foldReduceOperandViaVQDOT(A, DL, DAG, Subtarget, TLI);
+ SDValue BOpt = foldReduceOperandViaVQDOT(B, DL, DAG, Subtarget, TLI);
+ if (AOpt || BOpt) {
+ if (AOpt)
+ A = AOpt;
+ if (BOpt)
+ B = BOpt;
+ // From here, we're doing A + B with mixed types, implicitly zero
+ // padded to the wider type. Note that we *don't* need the result
+ // type to be the original VT, and in fact prefer narrower ones
+ // if possible.
+ return getZeroPaddedAdd(DL, A, B, DAG);
+ }
+ }
+
// reduce (zext a) <--> reduce (mul zext a. zext 1)
// reduce (sext a) <--> reduce (mul sext a. sext 1)
if (InVec.getOpcode() == ISD::ZERO_EXTEND ||
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
index edc9886abc3b9..e5546ad404c1b 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
@@ -299,17 +299,31 @@ entry:
}
define i32 @vqdot_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
-; CHECK-LABEL: vqdot_vv_accum:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT: vsext.vf2 v10, v8
-; CHECK-NEXT: vsext.vf2 v16, v9
-; CHECK-NEXT: vwmacc.vv v12, v10, v16
-; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT: vmv.s.x v8, zero
-; CHECK-NEXT: vredsum.vs v8, v12, v8
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: vqdot_vv_accum:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT: vsext.vf2 v10, v8
+; NODOT-NEXT: vsext.vf2 v16, v9
+; NODOT-NEXT: vwmacc.vv v12, v10, v16
+; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT: vmv.s.x v8, zero
+; NODOT-NEXT: vredsum.vs v8, v12, v8
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdot_vv_accum:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdot.vv v10, v8, v9
+; DOT-NEXT: vadd.vv v8, v10, v12
+; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
+; DOT-NEXT: vmv.v.v v12, v8
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
+; DOT-NEXT: vredsum.vs v8, v12, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.sext = sext <16 x i8> %a to <16 x i32>
%b.sext = sext <16 x i8> %b to <16 x i32>
@@ -320,17 +334,31 @@ entry:
}
define i32 @vqdotu_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
-; CHECK-LABEL: vqdotu_vv_accum:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e8, m1, ta, ma
-; CHECK-NEXT: vwmulu.vv v10, v8, v9
-; CHECK-NEXT: vsetvli zero, zero, e16, m2, ta, ma
-; CHECK-NEXT: vwaddu.wv v12, v12, v10
-; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT: vmv.s.x v8, zero
-; CHECK-NEXT: vredsum.vs v8, v12, v8
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: vqdotu_vv_accum:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e8, m1, ta, ma
+; NODOT-NEXT: vwmulu.vv v10, v8, v9
+; NODOT-NEXT: vsetvli zero, zero, e16, m2, ta, ma
+; NODOT-NEXT: vwaddu.wv v12, v12, v10
+; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT: vmv.s.x v8, zero
+; NODOT-NEXT: vredsum.vs v8, v12, v8
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdotu_vv_accum:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdotu.vv v10, v8, v9
+; DOT-NEXT: vadd.vv v8, v10, v12
+; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
+; DOT-NEXT: vmv.v.v v12, v8
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
+; DOT-NEXT: vredsum.vs v8, v12, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.zext = zext <16 x i8> %a to <16 x i32>
%b.zext = zext <16 x i8> %b to <16 x i32>
@@ -341,17 +369,31 @@ entry:
}
define i32 @vqdotsu_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
-; CHECK-LABEL: vqdotsu_vv_accum:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT: vsext.vf2 v10, v8
-; CHECK-NEXT: vzext.vf2 v16, v9
-; CHECK-NEXT: vwmaccsu.vv v12, v10, v16
-; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT: vmv.s.x v8, zero
-; CHECK-NEXT: vredsum.vs v8, v12, v8
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: vqdotsu_vv_accum:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT: vsext.vf2 v10, v8
+; NODOT-NEXT: vzext.vf2 v16, v9
+; NODOT-NEXT: vwmaccsu.vv v12, v10, v16
+; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT: vmv.s.x v8, zero
+; NODOT-NEXT: vredsum.vs v8, v12, v8
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdotsu_vv_accum:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdotsu.vv v10, v8, v9
+; DOT-NEXT: vadd.vv v8, v10, v12
+; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
+; DOT-NEXT: vmv.v.v v12, v8
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
+; DOT-NEXT: vredsum.vs v8, v12, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.sext = sext <16 x i8> %a to <16 x i32>
%b.zext = zext <16 x i8> %b to <16 x i32>
@@ -455,20 +497,33 @@ entry:
}
define i32 @vqdot_vv_split(<16 x i8> %a, <16 x i8> %b, <16 x i8> %c, <16 x i8> %d) {
-; CHECK-LABEL: vqdot_vv_split:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT: vsext.vf2 v12, v8
-; CHECK-NEXT: vsext.vf2 v14, v9
-; CHECK-NEXT: vsext.vf2 v16, v10
-; CHECK-NEXT: vsext.vf2 v18, v11
-; CHECK-NEXT: vwmul.vv v8, v12, v14
-; CHECK-NEXT: vwmacc.vv v8, v16, v18
-; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT: vmv.s.x v12, zero
-; CHECK-NEXT: vredsum.vs v8, v8, v12
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: vqdot_vv_split:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT: vsext.vf2 v12, v8
+; NODOT-NEXT: vsext.vf2 v14, v9
+; NODOT-NEXT: vsext.vf2 v16, v10
+; NODOT-NEXT: vsext.vf2 v18, v11
+; NODOT-NEXT: vwmul.vv v8, v12, v14
+; NODOT-NEXT: vwmacc.vv v8, v16, v18
+; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT: vmv.s.x v12, zero
+; NODOT-NEXT: vredsum.vs v8, v8, v12
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdot_vv_split:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v12, 0
+; DOT-NEXT: vmv.v.i v13, 0
+; DOT-NEXT: vqdot.vv v12, v8, v9
+; DOT-NEXT: vqdot.vv v13, v10, v11
+; DOT-NEXT: vadd.vv v8, v12, v13
+; DOT-NEXT: vmv.s.x v9, zero
+; DOT-NEXT: vredsum.vs v8, v8, v9
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.sext = sext <16 x i8> %a to <16 x i32>
%b.sext = sext <16 x i8> %b to <16 x i32>
``````````
</details>
https://github.com/llvm/llvm-project/pull/138965
More information about the llvm-commits
mailing list