[llvm] [AArch64][SVE] Add partial reduction SDNodes (PR #117185)
Benjamin Maxwell via llvm-commits
llvm-commits at lists.llvm.org
Mon Jan 20 08:37:02 PST 2025
================
@@ -22011,34 +22010,25 @@ static SDValue tryCombineWhileLo(SDNode *N,
return SDValue(N, 0);
}
-SDValue tryLowerPartialReductionToDot(SDNode *N,
- const AArch64Subtarget *Subtarget,
- SelectionDAG &DAG) {
-
- assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
- getIntrinsicID(N) ==
- Intrinsic::experimental_vector_partial_reduce_add &&
- "Expected a partial reduction node");
-
- bool Scalable = N->getValueType(0).isScalableVector();
+SDValue tryCombineToDotProduct(SDValue &Op0, SDValue &Op1, SDValue &Op2,
+ SelectionDAG &DAG,
+ const AArch64Subtarget *Subtarget, SDLoc &DL) {
+ bool Scalable = Op0->getValueType(0).isScalableVector();
if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
return SDValue();
if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
return SDValue();
- SDLoc DL(N);
-
- SDValue Op2 = N->getOperand(2);
- unsigned Op2Opcode = Op2->getOpcode();
+ unsigned Op1Opcode = Op1->getOpcode();
SDValue MulOpLHS, MulOpRHS;
bool MulOpLHSIsSigned, MulOpRHSIsSigned;
- if (ISD::isExtOpcode(Op2Opcode)) {
- MulOpLHSIsSigned = MulOpRHSIsSigned = (Op2Opcode == ISD::SIGN_EXTEND);
- MulOpLHS = Op2->getOperand(0);
- MulOpRHS = DAG.getConstant(1, DL, MulOpLHS.getValueType());
- } else if (Op2Opcode == ISD::MUL) {
- SDValue ExtMulOpLHS = Op2->getOperand(0);
- SDValue ExtMulOpRHS = Op2->getOperand(1);
+ if (ISD::isExtOpcode(Op1Opcode)) {
+ MulOpLHSIsSigned = MulOpRHSIsSigned = (Op1Opcode == ISD::SIGN_EXTEND);
+ MulOpLHS = Op1->getOperand(0);
+ MulOpRHS = DAG.getAnyExtOrTrunc(Op2, DL, MulOpLHS.getValueType());
+ } else if (Op1Opcode == ISD::MUL) {
+ SDValue ExtMulOpLHS = Op1->getOperand(0);
+ SDValue ExtMulOpRHS = Op1->getOperand(1);
----------------
MacDue wrote:
I'm a little confused by what this fold is doing now. `PARTIAL_REDUCE_MLA` is `acc += partial_reduce(op1 * op2)`, but this combine is also checking for another `MUL` node. I think this would be simpler to follow if there were a fold that took:
```
mul = op1 * op2
ret = PARTIAL_REDUCE_MLA(mul, splat(1))
```
And folded it to:
```
ret = PARTIAL_REDUCE_MLA(op1, op2)
```
Then the dot product lowering can be simpler (since it does not need to worry about the `MUL` node).
https://github.com/llvm/llvm-project/pull/117185
More information about the llvm-commits mailing list