[llvm] [RISCV] Fold add_vl into accumulator operand of vqdot* (PR #139484)
Craig Topper via llvm-commits
llvm-commits at lists.llvm.org
Mon May 12 13:01:18 PDT 2025
================
@@ -18459,9 +18459,74 @@ static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(Opc, DL, VT, Ops);
}
-static bool legalizeScatterGatherIndexType(SDLoc DL, SDValue &Index,
- ISD::MemIndexType &IndexType,
- RISCVTargetLowering::DAGCombinerInfo &DCI) {
+static SDValue combineVqdotAccum(SDNode *N, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
+
+ assert(N->getOpcode() == RISCVISD::ADD_VL);
+
+ if (!N->getValueType(0).isVector())
+ return SDValue();
+
+ SDValue Addend = N->getOperand(0);
+ SDValue DotOp = N->getOperand(1);
+
+ SDValue AddPassthruOp = N->getOperand(2);
+ if (!AddPassthruOp.isUndef())
+ return SDValue();
+
+ auto IsVqdotqOpc = [](unsigned Opc) {
+ switch (Opc) {
+ case RISCVISD::VQDOT_VL:
+ case RISCVISD::VQDOTU_VL:
+ case RISCVISD::VQDOTSU_VL:
+ return true;
+ default:
+ return false;
+ }
+ };
+
+ if (!IsVqdotqOpc(DotOp.getOpcode()))
+ std::swap(Addend, DotOp);
+
+ if (!IsVqdotqOpc(DotOp.getOpcode()))
+ return SDValue();
+
+ SDValue AddMask = N->getOperand(3);
+ SDValue AddVL = N->getOperand(4);
+
+ SDValue MulVL = DotOp.getOperand(4);
+ if (AddVL != MulVL)
+ return SDValue();
+
+ if (AddMask.getOpcode() != RISCVISD::VMSET_VL ||
+ AddMask.getOperand(0) != MulVL)
+ return SDValue();
+
+ SDValue AccumOp = DotOp.getOperand(2);
+ bool IsNullAdd = ISD::isConstantSplatVectorAllZeros(AccumOp.getNode());
+ // Peak through fixed to scalable
----------------
topperc wrote:
Peek*
https://github.com/llvm/llvm-project/pull/139484
More information about the llvm-commits
mailing list