[llvm] [SelectionDAG][AArch64] Add dot product lowering in NEON for PARTIAL_REDUCE_*MLA ISD nodes (PR #140075)
Benjamin Maxwell via llvm-commits
llvm-commits at lists.llvm.org
Wed May 21 03:40:22 PDT 2025
================
@@ -29518,37 +29533,58 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
}
/// If a PARTIAL_REDUCE_MLA node comes in with an accumulator-input type pairing
-/// of nxv2i64/nxv16i8, we cannot directly lower it to a (u|s)dot. We can
+/// of (nx)v2i64/(nx)v16i8, we cannot directly lower it to a (u|s)dot. We can
/// however still make use of the dot product instruction by instead
-/// accumulating over two steps: nxv16i8 -> nxv4i32 -> nxv2i64.
+/// accumulating over two steps: (nx)v16i8 -> (nx)v4i32 -> (nx)v2i64.
+/// If available, make use of the (U|S)ADDW(B|T) instructions, otherwise
+/// the following pattern is emitted:
+/// add(add(Acc, ext(EXTRACT_SUBVECTOR(N, 0))), ext(EXTRACT_SUBVECTOR(N,
+/// NTy/2)))
SDValue
AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
SelectionDAG &DAG) const {
+ bool Scalable = Op.getValueType().isScalableVector();
+ if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
+ return SDValue();
+ if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
+ return SDValue();
+
SDLoc DL(Op);
SDValue Acc = Op.getOperand(0);
SDValue LHS = Op.getOperand(1);
SDValue RHS = Op.getOperand(2);
EVT ResultVT = Op.getValueType();
- assert(ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8);
- SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, MVT::nxv4i32,
- DAG.getConstant(0, DL, MVT::nxv4i32), LHS, RHS);
+ assert((Scalable && ResultVT == MVT::nxv2i64 &&
+ LHS.getValueType() == MVT::nxv16i8) ||
+ (!Scalable && ResultVT == MVT::v2i64 &&
+ LHS.getValueType() == MVT::v16i8));
+
+ EVT DotVT = Scalable ? MVT::nxv4i32 : MVT::v4i32;
+ SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, DotVT,
+ DAG.getConstant(0, DL, DotVT), LHS, RHS);
bool IsUnsigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_UMLA;
- if (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable()) {
+ if (Scalable &&
+ (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable())) {
unsigned LoOpcode = IsUnsigned ? AArch64ISD::UADDWB : AArch64ISD::SADDWB;
unsigned HiOpcode = IsUnsigned ? AArch64ISD::UADDWT : AArch64ISD::SADDWT;
SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, Acc, DotNode);
return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
}
- unsigned LoOpcode = IsUnsigned ? AArch64ISD::UUNPKLO : AArch64ISD::SUNPKLO;
- unsigned HiOpcode = IsUnsigned ? AArch64ISD::UUNPKHI : AArch64ISD::SUNPKHI;
- auto Lo = DAG.getNode(LoOpcode, DL, ResultVT, DotNode);
- auto Hi = DAG.getNode(HiOpcode, DL, ResultVT, DotNode);
- auto Extended = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, Hi);
- return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended);
+ // Fold (nx)v4i32 into (nx)v2i64
+ auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
+ if (IsUnsigned) {
+ DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
+ DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
+ } else {
+ DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
+ DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
+ }
+ auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
+ return DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
----------------
MacDue wrote:
It looks like this regresses some cases (based on the failed tests). I think you could do something like:
```suggestion
// For SVE do: `Acc + (Lo + Hi)` (this avoids an extra add in some cases)
// For Neon do: `Hi + (Acc + Lo)` (this matches uaddw(2))
auto [Add1Op, Add2Op] = Scalable
? std::make_pair(DotNodeHi, Acc)
: std::make_pair(Acc, DotNodeHi);
SDValue Add = DAG.getNode(ISD::ADD, DL, ResultVT, DotNodeLo, Add1Op);
return DAG.getNode(ISD::ADD, DL, ResultVT, Add, Add2Op);
```
Or maybe adjust the later patterns.
https://github.com/llvm/llvm-project/pull/140075
More information about the llvm-commits
mailing list