[llvm] [RISCV] Fold add_vl into accumulator operand of vqdot* (PR #139484)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Mon May 12 13:01:18 PDT 2025


================
@@ -18459,9 +18459,74 @@ static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
   return DAG.getNode(Opc, DL, VT, Ops);
 }
 
-static bool legalizeScatterGatherIndexType(SDLoc DL, SDValue &Index,
-                                           ISD::MemIndexType &IndexType,
-                                           RISCVTargetLowering::DAGCombinerInfo &DCI) {
+static SDValue combineVqdotAccum(SDNode *N, SelectionDAG &DAG,
+                                 const RISCVSubtarget &Subtarget) {
+
+  assert(N->getOpcode() == RISCVISD::ADD_VL);
+
+  if (!N->getValueType(0).isVector())
+    return SDValue();
+
+  SDValue Addend = N->getOperand(0);
+  SDValue DotOp = N->getOperand(1);
+
+  SDValue AddPassthruOp = N->getOperand(2);
+  if (!AddPassthruOp.isUndef())
+    return SDValue();
+
+  auto IsVqdotqOpc = [](unsigned Opc) {
+    switch (Opc) {
+    case RISCVISD::VQDOT_VL:
+    case RISCVISD::VQDOTU_VL:
+    case RISCVISD::VQDOTSU_VL:
+      return true;
+    default:
+      return false;
+    }
+  };
+
+  if (!IsVqdotqOpc(DotOp.getOpcode()))
+    std::swap(Addend, DotOp);
+
+  if (!IsVqdotqOpc(DotOp.getOpcode()))
+    return SDValue();
+
+  SDValue AddMask = N->getOperand(3);
+  SDValue AddVL = N->getOperand(4);
+
+  SDValue MulVL = DotOp.getOperand(4);
+  if (AddVL != MulVL)
+    return SDValue();
+
+  if (AddMask.getOpcode() != RISCVISD::VMSET_VL ||
+      AddMask.getOperand(0) != MulVL)
+    return SDValue();
+
+  SDValue AccumOp = DotOp.getOperand(2);
+  bool IsNullAdd = ISD::isConstantSplatVectorAllZeros(AccumOp.getNode());
+  // Peak through fixed to scalable
----------------
topperc wrote:

Typo in the comment above: "Peak" should be "Peek".

https://github.com/llvm/llvm-project/pull/139484


More information about the llvm-commits mailing list