[llvm] [AArch64][SVE] Add partial reduction SDNodes (PR #117185)
Sander de Smalen via llvm-commits
llvm-commits at lists.llvm.org
Sun Jan 26 11:32:39 PST 2025
================
@@ -22011,138 +22010,188 @@ static SDValue tryCombineWhileLo(SDNode *N,
return SDValue(N, 0);
}
-SDValue tryLowerPartialReductionToDot(SDNode *N,
- const AArch64Subtarget *Subtarget,
- SelectionDAG &DAG) {
-
- assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
- getIntrinsicID(N) ==
- Intrinsic::experimental_vector_partial_reduce_add &&
- "Expected a partial reduction node");
-
- bool Scalable = N->getValueType(0).isScalableVector();
- if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
+SDValue tryCombinePartialReduceMLAMulOp(SDValue &Op0, SDValue &Op1,
+ SDValue &Op2, SelectionDAG &DAG,
+ SDLoc &DL) {
+ // Makes PARTIAL_REDUCE_MLA(Acc, MUL(EXT(MulOpLHS), EXT(MulOpRHS)), Splat (1))
+ // into PARTIAL_REDUCE_MLA(Acc, EXT(MulOpLHS), EXT(MulOpRHS))
+ if (Op1->getOpcode() != ISD::MUL)
return SDValue();
- if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
+
+ SDValue ExtMulOpLHS = Op1->getOperand(0);
+ SDValue ExtMulOpRHS = Op1->getOperand(1);
+ unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
+ unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
+ if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
+ !ISD::isExtOpcode(ExtMulOpRHSOpcode))
return SDValue();
- SDLoc DL(N);
+ SDValue MulOpLHS = ExtMulOpLHS->getOperand(0);
+ SDValue MulOpRHS = ExtMulOpRHS->getOperand(0);
+ EVT MulOpLHSVT = MulOpLHS.getValueType();
+ if (MulOpLHSVT != MulOpRHS.getValueType())
+ return SDValue();
- SDValue Op2 = N->getOperand(2);
unsigned Op2Opcode = Op2->getOpcode();
- SDValue MulOpLHS, MulOpRHS;
- bool MulOpLHSIsSigned, MulOpRHSIsSigned;
- if (ISD::isExtOpcode(Op2Opcode)) {
- MulOpLHSIsSigned = MulOpRHSIsSigned = (Op2Opcode == ISD::SIGN_EXTEND);
- MulOpLHS = Op2->getOperand(0);
- MulOpRHS = DAG.getConstant(1, DL, MulOpLHS.getValueType());
- } else if (Op2Opcode == ISD::MUL) {
- SDValue ExtMulOpLHS = Op2->getOperand(0);
- SDValue ExtMulOpRHS = Op2->getOperand(1);
-
- unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
- unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
- if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
- !ISD::isExtOpcode(ExtMulOpRHSOpcode))
- return SDValue();
+ if ((Op2Opcode != ISD::SPLAT_VECTOR && Op2Opcode != ISD::BUILD_VECTOR) ||
+ !isOneConstant(Op2->getOperand(0)))
----------------
sdesmalen-arm wrote:
Please use `ISD::isConstantSplatVector` instead, and then check that the splat value it returns (via its `APInt` out-parameter) is one using `isOne()`.
https://github.com/llvm/llvm-project/pull/117185
More information about the llvm-commits mailing list