[llvm] [RISCV] Fold add_vl into accumulator operand of vqdot* (PR #139484)

via llvm-commits llvm-commits at lists.llvm.org
Sun May 11 17:23:13 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Philip Reames (preames)

<details>
<summary>Changes</summary>

If we have an add_vl following a vqdot* instruction, we can move the add before the vqdot* instead, adding the addend into the vqdot*'s accumulator operand. For cases where the prior accumulator was zero, we can fold the add into the vqdot* instruction entirely. This directly parallels the folding we do for the multiply-add variants.

---
Full diff: https://github.com/llvm/llvm-project/pull/139484.diff


2 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+70-3) 
- (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll (+12-17) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index c53550ea3b23b..93aabdc004b42 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18459,9 +18459,74 @@ static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
   return DAG.getNode(Opc, DL, VT, Ops);
 }
 
-static bool legalizeScatterGatherIndexType(SDLoc DL, SDValue &Index,
-                                           ISD::MemIndexType &IndexType,
-                                           RISCVTargetLowering::DAGCombinerInfo &DCI) {
+static SDValue combineVqdotAccum(SDNode *N, SelectionDAG &DAG,
+                                 const RISCVSubtarget &Subtarget) {
+
+  assert(N->getOpcode() == RISCVISD::ADD_VL);
+
+  if (!N->getValueType(0).isVector())
+    return SDValue();
+
+  SDValue Addend = N->getOperand(0);
+  SDValue DotOp = N->getOperand(1);
+
+  SDValue AddPassthruOp = N->getOperand(2);
+  if (!AddPassthruOp.isUndef())
+    return SDValue();
+
+  auto IsVqdotqOpc = [](unsigned Opc) {
+    switch (Opc) {
+    case RISCVISD::VQDOT_VL:
+    case RISCVISD::VQDOTU_VL:
+    case RISCVISD::VQDOTSU_VL:
+      return true;
+    default:
+      return false;
+    }
+  };
+
+  if (!IsVqdotqOpc(DotOp.getOpcode()))
+    std::swap(Addend, DotOp);
+
+  if (!IsVqdotqOpc(DotOp.getOpcode()))
+    return SDValue();
+
+  SDValue AddMask = N->getOperand(3);
+  SDValue AddVL = N->getOperand(4);
+
+  SDValue MulVL = DotOp.getOperand(4);
+  if (AddVL != MulVL)
+    return SDValue();
+
+  if (AddMask.getOpcode() != RISCVISD::VMSET_VL ||
+      AddMask.getOperand(0) != MulVL)
+    return SDValue();
+
+  SDValue AccumOp = DotOp.getOperand(2);
+  bool IsNullAdd = ISD::isConstantSplatVectorAllZeros(AccumOp.getNode());
+  // Peek through a fixed-to-scalable insert_subvector into undef.
+  if (!IsNullAdd && AccumOp.getOpcode() == ISD::INSERT_SUBVECTOR &&
+      AccumOp.getOperand(0).isUndef())
+    IsNullAdd =
+        ISD::isConstantSplatVectorAllZeros(AccumOp.getOperand(1).getNode());
+
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+  // This manual constant folding is required; this case is not otherwise
+  // constant folded or combined.
+  if (!IsNullAdd)
+    Addend = DAG.getNode(RISCVISD::ADD_VL, DL, VT, AccumOp, Addend,
+                         DAG.getUNDEF(VT), AddMask, AddVL);
+
+  SDValue Ops[] = {DotOp.getOperand(0), DotOp.getOperand(1), Addend,
+                   DotOp.getOperand(3), DotOp->getOperand(4)};
+  return DAG.getNode(DotOp->getOpcode(), DL, VT, Ops);
+}
+
+static bool
+legalizeScatterGatherIndexType(SDLoc DL, SDValue &Index,
+                               ISD::MemIndexType &IndexType,
+                               RISCVTargetLowering::DAGCombinerInfo &DCI) {
   if (!DCI.isBeforeLegalize())
     return false;
 
@@ -19582,6 +19647,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
   case RISCVISD::ADD_VL:
     if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
       return V;
+    if (SDValue V = combineVqdotAccum(N, DAG, Subtarget))
+      return V;
     return combineToVWMACC(N, DAG, Subtarget);
   case RISCVISD::VWADD_W_VL:
   case RISCVISD::VWADDU_W_VL:
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
index e5546ad404c1b..ff61ef82176e6 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
@@ -314,11 +314,10 @@ define i32 @vqdot_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
 ; DOT-LABEL: vqdot_vv_accum:
 ; DOT:       # %bb.0: # %entry
 ; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
-; DOT-NEXT:    vmv.v.i v10, 0
-; DOT-NEXT:    vqdot.vv v10, v8, v9
-; DOT-NEXT:    vadd.vv v8, v10, v12
+; DOT-NEXT:    vmv1r.v v16, v12
+; DOT-NEXT:    vqdot.vv v16, v8, v9
 ; DOT-NEXT:    vsetivli zero, 4, e32, m4, tu, ma
-; DOT-NEXT:    vmv.v.v v12, v8
+; DOT-NEXT:    vmv.v.v v12, v16
 ; DOT-NEXT:    vmv.s.x v8, zero
 ; DOT-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
 ; DOT-NEXT:    vredsum.vs v8, v12, v8
@@ -349,11 +348,10 @@ define i32 @vqdotu_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
 ; DOT-LABEL: vqdotu_vv_accum:
 ; DOT:       # %bb.0: # %entry
 ; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
-; DOT-NEXT:    vmv.v.i v10, 0
-; DOT-NEXT:    vqdotu.vv v10, v8, v9
-; DOT-NEXT:    vadd.vv v8, v10, v12
+; DOT-NEXT:    vmv1r.v v16, v12
+; DOT-NEXT:    vqdotu.vv v16, v8, v9
 ; DOT-NEXT:    vsetivli zero, 4, e32, m4, tu, ma
-; DOT-NEXT:    vmv.v.v v12, v8
+; DOT-NEXT:    vmv.v.v v12, v16
 ; DOT-NEXT:    vmv.s.x v8, zero
 ; DOT-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
 ; DOT-NEXT:    vredsum.vs v8, v12, v8
@@ -384,11 +382,10 @@ define i32 @vqdotsu_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
 ; DOT-LABEL: vqdotsu_vv_accum:
 ; DOT:       # %bb.0: # %entry
 ; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
-; DOT-NEXT:    vmv.v.i v10, 0
-; DOT-NEXT:    vqdotsu.vv v10, v8, v9
-; DOT-NEXT:    vadd.vv v8, v10, v12
+; DOT-NEXT:    vmv1r.v v16, v12
+; DOT-NEXT:    vqdotsu.vv v16, v8, v9
 ; DOT-NEXT:    vsetivli zero, 4, e32, m4, tu, ma
-; DOT-NEXT:    vmv.v.v v12, v8
+; DOT-NEXT:    vmv.v.v v12, v16
 ; DOT-NEXT:    vmv.s.x v8, zero
 ; DOT-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
 ; DOT-NEXT:    vredsum.vs v8, v12, v8
@@ -516,12 +513,10 @@ define i32 @vqdot_vv_split(<16 x i8> %a, <16 x i8> %b, <16 x i8> %c, <16 x i8> %
 ; DOT:       # %bb.0: # %entry
 ; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
 ; DOT-NEXT:    vmv.v.i v12, 0
-; DOT-NEXT:    vmv.v.i v13, 0
 ; DOT-NEXT:    vqdot.vv v12, v8, v9
-; DOT-NEXT:    vqdot.vv v13, v10, v11
-; DOT-NEXT:    vadd.vv v8, v12, v13
-; DOT-NEXT:    vmv.s.x v9, zero
-; DOT-NEXT:    vredsum.vs v8, v8, v9
+; DOT-NEXT:    vqdot.vv v12, v10, v11
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vredsum.vs v8, v12, v8
 ; DOT-NEXT:    vmv.x.s a0, v8
 ; DOT-NEXT:    ret
 entry:

``````````

</details>


https://github.com/llvm/llvm-project/pull/139484


More information about the llvm-commits mailing list