[llvm] [AArch64][SVE] Add partial reduction SDNodes (PR #117185)
James Chesterman via llvm-commits
llvm-commits at lists.llvm.org
Thu Jan 23 06:34:09 PST 2025
================
@@ -22011,34 +22010,25 @@ static SDValue tryCombineWhileLo(SDNode *N,
return SDValue(N, 0);
}
-SDValue tryLowerPartialReductionToDot(SDNode *N,
- const AArch64Subtarget *Subtarget,
- SelectionDAG &DAG) {
-
- assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
- getIntrinsicID(N) ==
- Intrinsic::experimental_vector_partial_reduce_add &&
- "Expected a partial reduction node");
-
- bool Scalable = N->getValueType(0).isScalableVector();
+SDValue tryCombineToDotProduct(SDValue &Op0, SDValue &Op1, SDValue &Op2,
+ SelectionDAG &DAG,
+ const AArch64Subtarget *Subtarget, SDLoc &DL) {
+ bool Scalable = Op0->getValueType(0).isScalableVector();
if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
return SDValue();
if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
return SDValue();
- SDLoc DL(N);
-
- SDValue Op2 = N->getOperand(2);
- unsigned Op2Opcode = Op2->getOpcode();
+ unsigned Op1Opcode = Op1->getOpcode();
SDValue MulOpLHS, MulOpRHS;
bool MulOpLHSIsSigned, MulOpRHSIsSigned;
- if (ISD::isExtOpcode(Op2Opcode)) {
- MulOpLHSIsSigned = MulOpRHSIsSigned = (Op2Opcode == ISD::SIGN_EXTEND);
- MulOpLHS = Op2->getOperand(0);
- MulOpRHS = DAG.getConstant(1, DL, MulOpLHS.getValueType());
- } else if (Op2Opcode == ISD::MUL) {
- SDValue ExtMulOpLHS = Op2->getOperand(0);
- SDValue ExtMulOpRHS = Op2->getOperand(1);
+ if (ISD::isExtOpcode(Op1Opcode)) {
+ MulOpLHSIsSigned = MulOpRHSIsSigned = (Op1Opcode == ISD::SIGN_EXTEND);
+ MulOpLHS = Op1->getOperand(0);
+ MulOpRHS = DAG.getAnyExtOrTrunc(Op2, DL, MulOpLHS.getValueType());
+ } else if (Op1Opcode == ISD::MUL) {
+ SDValue ExtMulOpLHS = Op1->getOperand(0);
+ SDValue ExtMulOpRHS = Op1->getOperand(1);
----------------
JamesChesterman wrote:
Done. Now there is a separate function for lowering:
`PARTIAL_REDUCE_MLA(Acc, MUL(EXT(MulOpLHS), EXT(MulOpRHS)), SPLAT(1))`
To:
`PARTIAL_REDUCE_MLA(Acc, EXT(MulOpLHS), EXT(MulOpRHS))`
The function `tryCombineToDotProduct` can then handle this pattern (removing the extends), as well as the pattern `PARTIAL_REDUCE_MLA(Acc, EXT(Op), SPLAT(1))`.
I've added comments in the relevant places detailing what happens.
https://github.com/llvm/llvm-project/pull/117185
More information about the llvm-commits mailing list