[llvm] [RISCV] Fold add_vl into accumulator operand of vqdot* (PR #139484)
via llvm-commits
llvm-commits at lists.llvm.org
Sun May 11 17:23:13 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-risc-v
Author: Philip Reames (preames)
<details>
<summary>Changes</summary>
If we have an add_vl following a vqdot* instruction, we can move the add before the vqdot instead. For cases where the prior accumulator was zero, we can fold the add into the vqdot* instruction entirely. This directly parallels the folding we do for multiply-add variants.
---
Full diff: https://github.com/llvm/llvm-project/pull/139484.diff
2 Files Affected:
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+70-3)
- (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll (+12-17)
``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index c53550ea3b23b..93aabdc004b42 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18459,9 +18459,74 @@ static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(Opc, DL, VT, Ops);
}
-static bool legalizeScatterGatherIndexType(SDLoc DL, SDValue &Index,
- ISD::MemIndexType &IndexType,
- RISCVTargetLowering::DAGCombinerInfo &DCI) {
+static SDValue combineVqdotAccum(SDNode *N, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
+
+ assert(N->getOpcode() == RISCVISD::ADD_VL);
+
+ if (!N->getValueType(0).isVector())
+ return SDValue();
+
+ SDValue Addend = N->getOperand(0);
+ SDValue DotOp = N->getOperand(1);
+
+ SDValue AddPassthruOp = N->getOperand(2);
+ if (!AddPassthruOp.isUndef())
+ return SDValue();
+
+ auto IsVqdotqOpc = [](unsigned Opc) {
+ switch (Opc) {
+ case RISCVISD::VQDOT_VL:
+ case RISCVISD::VQDOTU_VL:
+ case RISCVISD::VQDOTSU_VL:
+ return true;
+ default:
+ return false;
+ }
+ };
+
+ if (!IsVqdotqOpc(DotOp.getOpcode()))
+ std::swap(Addend, DotOp);
+
+ if (!IsVqdotqOpc(DotOp.getOpcode()))
+ return SDValue();
+
+ SDValue AddMask = N->getOperand(3);
+ SDValue AddVL = N->getOperand(4);
+
+ SDValue MulVL = DotOp.getOperand(4);
+ if (AddVL != MulVL)
+ return SDValue();
+
+ if (AddMask.getOpcode() != RISCVISD::VMSET_VL ||
+ AddMask.getOperand(0) != MulVL)
+ return SDValue();
+
+ SDValue AccumOp = DotOp.getOperand(2);
+ bool IsNullAdd = ISD::isConstantSplatVectorAllZeros(AccumOp.getNode());
+ // Peek through fixed to scalable
+ if (!IsNullAdd && AccumOp.getOpcode() == ISD::INSERT_SUBVECTOR &&
+ AccumOp.getOperand(0).isUndef())
+ IsNullAdd =
+ ISD::isConstantSplatVectorAllZeros(AccumOp.getOperand(1).getNode());
+
+ SDLoc DL(N);
+ EVT VT = N->getValueType(0);
+ // The manual constant folding is required, this case is not constant folded
+ // or combined.
+ if (!IsNullAdd)
+ Addend = DAG.getNode(RISCVISD::ADD_VL, DL, VT, AccumOp, Addend,
+ DAG.getUNDEF(VT), AddMask, AddVL);
+
+ SDValue Ops[] = {DotOp.getOperand(0), DotOp.getOperand(1), Addend,
+ DotOp.getOperand(3), DotOp->getOperand(4)};
+ return DAG.getNode(DotOp->getOpcode(), DL, VT, Ops);
+}
+
+static bool
+legalizeScatterGatherIndexType(SDLoc DL, SDValue &Index,
+ ISD::MemIndexType &IndexType,
+ RISCVTargetLowering::DAGCombinerInfo &DCI) {
if (!DCI.isBeforeLegalize())
return false;
@@ -19582,6 +19647,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
case RISCVISD::ADD_VL:
if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
return V;
+ if (SDValue V = combineVqdotAccum(N, DAG, Subtarget))
+ return V;
return combineToVWMACC(N, DAG, Subtarget);
case RISCVISD::VWADD_W_VL:
case RISCVISD::VWADDU_W_VL:
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
index e5546ad404c1b..ff61ef82176e6 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
@@ -314,11 +314,10 @@ define i32 @vqdot_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
; DOT-LABEL: vqdot_vv_accum:
; DOT: # %bb.0: # %entry
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
-; DOT-NEXT: vmv.v.i v10, 0
-; DOT-NEXT: vqdot.vv v10, v8, v9
-; DOT-NEXT: vadd.vv v8, v10, v12
+; DOT-NEXT: vmv1r.v v16, v12
+; DOT-NEXT: vqdot.vv v16, v8, v9
; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
-; DOT-NEXT: vmv.v.v v12, v8
+; DOT-NEXT: vmv.v.v v12, v16
; DOT-NEXT: vmv.s.x v8, zero
; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
; DOT-NEXT: vredsum.vs v8, v12, v8
@@ -349,11 +348,10 @@ define i32 @vqdotu_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
; DOT-LABEL: vqdotu_vv_accum:
; DOT: # %bb.0: # %entry
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
-; DOT-NEXT: vmv.v.i v10, 0
-; DOT-NEXT: vqdotu.vv v10, v8, v9
-; DOT-NEXT: vadd.vv v8, v10, v12
+; DOT-NEXT: vmv1r.v v16, v12
+; DOT-NEXT: vqdotu.vv v16, v8, v9
; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
-; DOT-NEXT: vmv.v.v v12, v8
+; DOT-NEXT: vmv.v.v v12, v16
; DOT-NEXT: vmv.s.x v8, zero
; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
; DOT-NEXT: vredsum.vs v8, v12, v8
@@ -384,11 +382,10 @@ define i32 @vqdotsu_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
; DOT-LABEL: vqdotsu_vv_accum:
; DOT: # %bb.0: # %entry
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
-; DOT-NEXT: vmv.v.i v10, 0
-; DOT-NEXT: vqdotsu.vv v10, v8, v9
-; DOT-NEXT: vadd.vv v8, v10, v12
+; DOT-NEXT: vmv1r.v v16, v12
+; DOT-NEXT: vqdotsu.vv v16, v8, v9
; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
-; DOT-NEXT: vmv.v.v v12, v8
+; DOT-NEXT: vmv.v.v v12, v16
; DOT-NEXT: vmv.s.x v8, zero
; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
; DOT-NEXT: vredsum.vs v8, v12, v8
@@ -516,12 +513,10 @@ define i32 @vqdot_vv_split(<16 x i8> %a, <16 x i8> %b, <16 x i8> %c, <16 x i8> %
; DOT: # %bb.0: # %entry
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
; DOT-NEXT: vmv.v.i v12, 0
-; DOT-NEXT: vmv.v.i v13, 0
; DOT-NEXT: vqdot.vv v12, v8, v9
-; DOT-NEXT: vqdot.vv v13, v10, v11
-; DOT-NEXT: vadd.vv v8, v12, v13
-; DOT-NEXT: vmv.s.x v9, zero
-; DOT-NEXT: vredsum.vs v8, v8, v9
+; DOT-NEXT: vqdot.vv v12, v10, v11
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vredsum.vs v8, v12, v8
; DOT-NEXT: vmv.x.s a0, v8
; DOT-NEXT: ret
entry:
``````````
</details>
https://github.com/llvm/llvm-project/pull/139484
More information about the llvm-commits
mailing list