[llvm] 6408291 - [RISCV] Fold add_vl into accumulator operand of vqdot* (#139484)



Author: Philip Reames
Date: 2025-05-12T15:04:45-07:00
New Revision: 64082912a500d004c53ad1b3425098b495572663

URL: https://github.com/llvm/llvm-project/commit/64082912a500d004c53ad1b3425098b495572663
DIFF: https://github.com/llvm/llvm-project/commit/64082912a500d004c53ad1b3425098b495572663.diff

LOG: [RISCV] Fold add_vl into accumulator operand of vqdot* (#139484)

If we have an add_vl following a vqdot* instruction, we can instead move
the add before the vqdot. For cases where the prior accumulator was
zero, we can fold the add into the vqdot* instruction entirely. This
directly parallels the folding we do for the multiply-add variants.
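
Conceptually, the DAG rewrite looks like this (pseudo-notation based on
the ADD_VL and VQDOT*_VL operand layouts in the patch, not exact node
syntax):

  ;; Accumulator is a zero splat: fold the add away entirely.
  (add_vl (vqdot_vl a, b, zero-splat, mask, vl), addend, undef, mask, vl)
    --> (vqdot_vl a, b, addend, mask, vl)

  ;; Otherwise: move the add in front of the dot product.
  (add_vl (vqdot_vl a, b, accum, mask, vl), addend, undef, mask, vl)
    --> (vqdot_vl a, b, (add_vl accum, addend, undef, mask, vl), mask, vl)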

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 158a3afdb864c..0f93bd9fef060 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18408,7 +18408,6 @@ static SDValue performVECTOR_SHUFFLECombine(SDNode *N, SelectionDAG &DAG,
 
 static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
                                const RISCVSubtarget &Subtarget) {
-
   assert(N->getOpcode() == RISCVISD::ADD_VL || N->getOpcode() == ISD::ADD);
 
   if (N->getValueType(0).isFixedLengthVector())
@@ -18472,9 +18471,74 @@ static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
   return DAG.getNode(Opc, DL, VT, Ops);
 }
 
-static bool legalizeScatterGatherIndexType(SDLoc DL, SDValue &Index,
-                                           ISD::MemIndexType &IndexType,
-                                           RISCVTargetLowering::DAGCombinerInfo &DCI) {
+static SDValue combineVqdotAccum(SDNode *N, SelectionDAG &DAG,
+                                 const RISCVSubtarget &Subtarget) {
+
+  assert(N->getOpcode() == RISCVISD::ADD_VL);
+
+  if (!N->getValueType(0).isVector())
+    return SDValue();
+
+  SDValue Addend = N->getOperand(0);
+  SDValue DotOp = N->getOperand(1);
+
+  SDValue AddPassthruOp = N->getOperand(2);
+  if (!AddPassthruOp.isUndef())
+    return SDValue();
+
+  auto IsVqdotqOpc = [](unsigned Opc) {
+    switch (Opc) {
+    case RISCVISD::VQDOT_VL:
+    case RISCVISD::VQDOTU_VL:
+    case RISCVISD::VQDOTSU_VL:
+      return true;
+    default:
+      return false;
+    }
+  };
+
+  if (!IsVqdotqOpc(DotOp.getOpcode()))
+    std::swap(Addend, DotOp);
+
+  if (!IsVqdotqOpc(DotOp.getOpcode()))
+    return SDValue();
+
+  SDValue AddMask = N->getOperand(3);
+  SDValue AddVL = N->getOperand(4);
+
+  SDValue MulVL = DotOp.getOperand(4);
+  if (AddVL != MulVL)
+    return SDValue();
+
+  if (AddMask.getOpcode() != RISCVISD::VMSET_VL ||
+      AddMask.getOperand(0) != MulVL)
+    return SDValue();
+
+  SDValue AccumOp = DotOp.getOperand(2);
+  bool IsNullAdd = ISD::isConstantSplatVectorAllZeros(AccumOp.getNode());
+  // Peek through fixed to scalable
+  if (!IsNullAdd && AccumOp.getOpcode() == ISD::INSERT_SUBVECTOR &&
+      AccumOp.getOperand(0).isUndef())
+    IsNullAdd =
+        ISD::isConstantSplatVectorAllZeros(AccumOp.getOperand(1).getNode());
+
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+  // The manual constant folding is required, as this case is not constant
+  // folded or combined.
+  if (!IsNullAdd)
+    Addend = DAG.getNode(RISCVISD::ADD_VL, DL, VT, AccumOp, Addend,
+                         DAG.getUNDEF(VT), AddMask, AddVL);
+
+  SDValue Ops[] = {DotOp.getOperand(0), DotOp.getOperand(1), Addend,
+                   DotOp.getOperand(3), DotOp->getOperand(4)};
+  return DAG.getNode(DotOp->getOpcode(), DL, VT, Ops);
+}
+
+static bool
+legalizeScatterGatherIndexType(SDLoc DL, SDValue &Index,
+                               ISD::MemIndexType &IndexType,
+                               RISCVTargetLowering::DAGCombinerInfo &DCI) {
   if (!DCI.isBeforeLegalize())
     return false;
 
@@ -19595,6 +19659,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
   case RISCVISD::ADD_VL:
     if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
       return V;
+    if (SDValue V = combineVqdotAccum(N, DAG, Subtarget))
+      return V;
     return combineToVWMACC(N, DAG, Subtarget);
   case RISCVISD::VWADD_W_VL:
   case RISCVISD::VWADDU_W_VL:

diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
index e5546ad404c1b..ff61ef82176e6 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
@@ -314,11 +314,10 @@ define i32 @vqdot_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
 ; DOT-LABEL: vqdot_vv_accum:
 ; DOT:       # %bb.0: # %entry
 ; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
-; DOT-NEXT:    vmv.v.i v10, 0
-; DOT-NEXT:    vqdot.vv v10, v8, v9
-; DOT-NEXT:    vadd.vv v8, v10, v12
+; DOT-NEXT:    vmv1r.v v16, v12
+; DOT-NEXT:    vqdot.vv v16, v8, v9
 ; DOT-NEXT:    vsetivli zero, 4, e32, m4, tu, ma
-; DOT-NEXT:    vmv.v.v v12, v8
+; DOT-NEXT:    vmv.v.v v12, v16
 ; DOT-NEXT:    vmv.s.x v8, zero
 ; DOT-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
 ; DOT-NEXT:    vredsum.vs v8, v12, v8
@@ -349,11 +348,10 @@ define i32 @vqdotu_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
 ; DOT-LABEL: vqdotu_vv_accum:
 ; DOT:       # %bb.0: # %entry
 ; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
-; DOT-NEXT:    vmv.v.i v10, 0
-; DOT-NEXT:    vqdotu.vv v10, v8, v9
-; DOT-NEXT:    vadd.vv v8, v10, v12
+; DOT-NEXT:    vmv1r.v v16, v12
+; DOT-NEXT:    vqdotu.vv v16, v8, v9
 ; DOT-NEXT:    vsetivli zero, 4, e32, m4, tu, ma
-; DOT-NEXT:    vmv.v.v v12, v8
+; DOT-NEXT:    vmv.v.v v12, v16
 ; DOT-NEXT:    vmv.s.x v8, zero
 ; DOT-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
 ; DOT-NEXT:    vredsum.vs v8, v12, v8
@@ -384,11 +382,10 @@ define i32 @vqdotsu_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
 ; DOT-LABEL: vqdotsu_vv_accum:
 ; DOT:       # %bb.0: # %entry
 ; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
-; DOT-NEXT:    vmv.v.i v10, 0
-; DOT-NEXT:    vqdotsu.vv v10, v8, v9
-; DOT-NEXT:    vadd.vv v8, v10, v12
+; DOT-NEXT:    vmv1r.v v16, v12
+; DOT-NEXT:    vqdotsu.vv v16, v8, v9
 ; DOT-NEXT:    vsetivli zero, 4, e32, m4, tu, ma
-; DOT-NEXT:    vmv.v.v v12, v8
+; DOT-NEXT:    vmv.v.v v12, v16
 ; DOT-NEXT:    vmv.s.x v8, zero
 ; DOT-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
 ; DOT-NEXT:    vredsum.vs v8, v12, v8
@@ -516,12 +513,10 @@ define i32 @vqdot_vv_split(<16 x i8> %a, <16 x i8> %b, <16 x i8> %c, <16 x i8> %
 ; DOT:       # %bb.0: # %entry
 ; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
 ; DOT-NEXT:    vmv.v.i v12, 0
-; DOT-NEXT:    vmv.v.i v13, 0
 ; DOT-NEXT:    vqdot.vv v12, v8, v9
-; DOT-NEXT:    vqdot.vv v13, v10, v11
-; DOT-NEXT:    vadd.vv v8, v12, v13
-; DOT-NEXT:    vmv.s.x v9, zero
-; DOT-NEXT:    vredsum.vs v8, v8, v9
+; DOT-NEXT:    vqdot.vv v12, v10, v11
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vredsum.vs v8, v12, v8
 ; DOT-NEXT:    vmv.x.s a0, v8
 ; DOT-NEXT:    ret
 entry:
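
For reference, the vqdot_vv_accum test above exercises IR of roughly the
following shape (a sketch reconstructed from the function signature and
the CHECK lines; the actual test body may differ in detail):

define i32 @vqdot_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
entry:
  ; widen both inputs, multiply, accumulate into %x, then reduce to a scalar
  %a.sext = sext <16 x i8> %a to <16 x i32>
  %b.sext = sext <16 x i8> %b to <16 x i32>
  %mul = mul <16 x i32> %a.sext, %b.sext
  %acc = add <16 x i32> %mul, %x
  %sum = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %acc)
  ret i32 %sum
}

declare i32 @llvm.vector.reduce.add.v16i32(<16 x i32>)

With the combine, the zero-splat accumulator previously materialized by
vmv.v.i is gone and %x is copied into the vqdot.vv destination register
directly, as the updated CHECK lines show.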
