[llvm] 7e64ade - [RISCV] Extend zvqdot matching to handle reduction trees (#138965)
via llvm-commits
llvm-commits at lists.llvm.org
Fri May 9 08:10:02 PDT 2025
Author: Philip Reames
Date: 2025-05-09T08:09:46-07:00
New Revision: 7e64ade2ef1af72db235f6d76ecd319abc09690e
URL: https://github.com/llvm/llvm-project/commit/7e64ade2ef1af72db235f6d76ecd319abc09690e
DIFF: https://github.com/llvm/llvm-project/commit/7e64ade2ef1af72db235f6d76ecd319abc09690e.diff
LOG: [RISCV] Extend zvqdot matching to handle reduction trees (#138965)
Now that we have matching for vqdot in its basic variants, we can
extend the matcher to handle reduction trees instead of individual
reductions. This matters because we canonicalize reductions by
performing a tree of adds in the vector domain before the root
reduction instruction.
The particular approach taken here has the unfortunate implication that
non-matches visit the entire reduction tree once for each time the
reduction root is visited in the DAG. While conceptually problematic for
compile time, this is probably fine in practice as we should only visit
the root once per pass of DAGCombine. I don't really see a better
solution - suggestions welcome.
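As a concrete illustration (the function name is my own; the value names
follow the vqdot_vv_split test updated below), the shape the extended
matcher now recognizes is a vecreduce.add whose operand is itself an add
tree of dot-product partial sums:
  define i32 @reduction_tree(<16 x i8> %a, <16 x i8> %b, <16 x i8> %c, <16 x i8> %d) {
  entry:
    %a.sext = sext <16 x i8> %a to <16 x i32>
    %b.sext = sext <16 x i8> %b to <16 x i32>
    %c.sext = sext <16 x i8> %c to <16 x i32>
    %d.sext = sext <16 x i8> %d to <16 x i32>
    %mul1 = mul <16 x i32> %a.sext, %b.sext
    %mul2 = mul <16 x i32> %c.sext, %d.sext
    %add = add <16 x i32> %mul1, %mul2
    %sum = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %add)
    ret i32 %sum
  }
Each mul then lowers to its own vqdot.vv accumulation, and the interior add
becomes a single vadd.vv of the partial results, as the DOT check lines for
vqdot_vv_split show.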
---------
Co-authored-by: Luke Lau <luke_lau at icloud.com>
Added:
Modified:
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 36d144ea92b22..fd1d4d439fd7b 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18054,6 +18054,27 @@ static MVT getQDOTXResultType(MVT OpVT) {
return MVT::getVectorVT(MVT::i32, OpEC.divideCoefficientBy(4));
}
+/// Given fixed length vectors A and B with equal element types, but possibly
+/// different number of elements, return A + B where either A or B is zero
+/// padded to the larger number of elements.
+static SDValue getZeroPaddedAdd(const SDLoc &DL, SDValue A, SDValue B,
+ SelectionDAG &DAG) {
+ // NOTE: Manually doing the extract/add/insert scheme produces
+ // significantly better codegen than the naive pad with zeros
+ // and add scheme.
+ EVT AVT = A.getValueType();
+ EVT BVT = B.getValueType();
+ assert(AVT.getVectorElementType() == BVT.getVectorElementType());
+ if (AVT.getVectorNumElements() > BVT.getVectorNumElements()) {
+ std::swap(A, B);
+ std::swap(AVT, BVT);
+ }
+
+ SDValue BPart = DAG.getExtractSubvector(DL, AVT, B, 0);
+ SDValue Res = DAG.getNode(ISD::ADD, DL, AVT, A, BPart);
+ return DAG.getInsertSubvector(DL, B, Res, 0);
+}
+
static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
SelectionDAG &DAG,
const RISCVSubtarget &Subtarget,
@@ -18065,6 +18086,26 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
!InVec.getValueType().getVectorElementCount().isKnownMultipleOf(4))
return SDValue();
+ // Recurse through adds (since generic dag canonicalizes to that
+ // form). TODO: Handle disjoint or here.
+ if (InVec->getOpcode() == ISD::ADD) {
+ SDValue A = InVec.getOperand(0);
+ SDValue B = InVec.getOperand(1);
+ SDValue AOpt = foldReduceOperandViaVQDOT(A, DL, DAG, Subtarget, TLI);
+ SDValue BOpt = foldReduceOperandViaVQDOT(B, DL, DAG, Subtarget, TLI);
+ if (AOpt || BOpt) {
+ if (AOpt)
+ A = AOpt;
+ if (BOpt)
+ B = BOpt;
+ // From here, we're doing A + B with mixed types, implicitly zero
+ // padded to the wider type. Note that we *don't* need the result
+ // type to be the original VT, and in fact prefer narrower ones
+ // if possible.
+ return getZeroPaddedAdd(DL, A, B, DAG);
+ }
+ }
+
// reduce (zext a) <--> reduce (mul zext a. zext 1)
// reduce (sext a) <--> reduce (mul sext a. sext 1)
if (InVec.getOpcode() == ISD::ZERO_EXTEND ||
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
index edc9886abc3b9..e5546ad404c1b 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
@@ -299,17 +299,31 @@ entry:
}
define i32 @vqdot_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
-; CHECK-LABEL: vqdot_vv_accum:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT: vsext.vf2 v10, v8
-; CHECK-NEXT: vsext.vf2 v16, v9
-; CHECK-NEXT: vwmacc.vv v12, v10, v16
-; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT: vmv.s.x v8, zero
-; CHECK-NEXT: vredsum.vs v8, v12, v8
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: vqdot_vv_accum:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT: vsext.vf2 v10, v8
+; NODOT-NEXT: vsext.vf2 v16, v9
+; NODOT-NEXT: vwmacc.vv v12, v10, v16
+; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT: vmv.s.x v8, zero
+; NODOT-NEXT: vredsum.vs v8, v12, v8
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdot_vv_accum:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdot.vv v10, v8, v9
+; DOT-NEXT: vadd.vv v8, v10, v12
+; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
+; DOT-NEXT: vmv.v.v v12, v8
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
+; DOT-NEXT: vredsum.vs v8, v12, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.sext = sext <16 x i8> %a to <16 x i32>
%b.sext = sext <16 x i8> %b to <16 x i32>
@@ -320,17 +334,31 @@ entry:
}
define i32 @vqdotu_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
-; CHECK-LABEL: vqdotu_vv_accum:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e8, m1, ta, ma
-; CHECK-NEXT: vwmulu.vv v10, v8, v9
-; CHECK-NEXT: vsetvli zero, zero, e16, m2, ta, ma
-; CHECK-NEXT: vwaddu.wv v12, v12, v10
-; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT: vmv.s.x v8, zero
-; CHECK-NEXT: vredsum.vs v8, v12, v8
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: vqdotu_vv_accum:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e8, m1, ta, ma
+; NODOT-NEXT: vwmulu.vv v10, v8, v9
+; NODOT-NEXT: vsetvli zero, zero, e16, m2, ta, ma
+; NODOT-NEXT: vwaddu.wv v12, v12, v10
+; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT: vmv.s.x v8, zero
+; NODOT-NEXT: vredsum.vs v8, v12, v8
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdotu_vv_accum:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdotu.vv v10, v8, v9
+; DOT-NEXT: vadd.vv v8, v10, v12
+; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
+; DOT-NEXT: vmv.v.v v12, v8
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
+; DOT-NEXT: vredsum.vs v8, v12, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.zext = zext <16 x i8> %a to <16 x i32>
%b.zext = zext <16 x i8> %b to <16 x i32>
@@ -341,17 +369,31 @@ entry:
}
define i32 @vqdotsu_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
-; CHECK-LABEL: vqdotsu_vv_accum:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT: vsext.vf2 v10, v8
-; CHECK-NEXT: vzext.vf2 v16, v9
-; CHECK-NEXT: vwmaccsu.vv v12, v10, v16
-; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT: vmv.s.x v8, zero
-; CHECK-NEXT: vredsum.vs v8, v12, v8
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: vqdotsu_vv_accum:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT: vsext.vf2 v10, v8
+; NODOT-NEXT: vzext.vf2 v16, v9
+; NODOT-NEXT: vwmaccsu.vv v12, v10, v16
+; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT: vmv.s.x v8, zero
+; NODOT-NEXT: vredsum.vs v8, v12, v8
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdotsu_vv_accum:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdotsu.vv v10, v8, v9
+; DOT-NEXT: vadd.vv v8, v10, v12
+; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
+; DOT-NEXT: vmv.v.v v12, v8
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
+; DOT-NEXT: vredsum.vs v8, v12, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.sext = sext <16 x i8> %a to <16 x i32>
%b.zext = zext <16 x i8> %b to <16 x i32>
@@ -455,20 +497,33 @@ entry:
}
define i32 @vqdot_vv_split(<16 x i8> %a, <16 x i8> %b, <16 x i8> %c, <16 x i8> %d) {
-; CHECK-LABEL: vqdot_vv_split:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT: vsext.vf2 v12, v8
-; CHECK-NEXT: vsext.vf2 v14, v9
-; CHECK-NEXT: vsext.vf2 v16, v10
-; CHECK-NEXT: vsext.vf2 v18, v11
-; CHECK-NEXT: vwmul.vv v8, v12, v14
-; CHECK-NEXT: vwmacc.vv v8, v16, v18
-; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT: vmv.s.x v12, zero
-; CHECK-NEXT: vredsum.vs v8, v8, v12
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: vqdot_vv_split:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT: vsext.vf2 v12, v8
+; NODOT-NEXT: vsext.vf2 v14, v9
+; NODOT-NEXT: vsext.vf2 v16, v10
+; NODOT-NEXT: vsext.vf2 v18, v11
+; NODOT-NEXT: vwmul.vv v8, v12, v14
+; NODOT-NEXT: vwmacc.vv v8, v16, v18
+; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT: vmv.s.x v12, zero
+; NODOT-NEXT: vredsum.vs v8, v8, v12
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdot_vv_split:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v12, 0
+; DOT-NEXT: vmv.v.i v13, 0
+; DOT-NEXT: vqdot.vv v12, v8, v9
+; DOT-NEXT: vqdot.vv v13, v10, v11
+; DOT-NEXT: vadd.vv v8, v12, v13
+; DOT-NEXT: vmv.s.x v9, zero
+; DOT-NEXT: vredsum.vs v8, v8, v9
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.sext = sext <16 x i8> %a to <16 x i32>
%b.sext = sext <16 x i8> %b to <16 x i32>
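For readers skimming the patch, a minimal IR-level sketch of what the new
getZeroPaddedAdd helper computes (the real code builds SelectionDAG nodes,
not IR, and the value names here are illustrative): adding a <4 x i32>
partial result %small into a <16 x i32> accumulator %big, with %small
implicitly zero padded, is done via extract/add/insert on the low element
group rather than by materializing the padded vector:
  %lo  = call <4 x i32> @llvm.vector.extract.v4i32.v16i32(<16 x i32> %big, i64 0)
  %sum = add <4 x i32> %small, %lo
  %res = call <16 x i32> @llvm.vector.insert.v16i32.v4i32(<16 x i32> %big, <4 x i32> %sum, i64 0)
This is the pattern visible in the DOT check lines above: a vadd.vv of the
m1 dot result, then a tail-undisturbed vmv.v.v back into the m4 accumulator.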