[llvm] [AArch64][SVE] Add partial reduction SDNodes (PR #117185)
Sander de Smalen via llvm-commits
llvm-commits at lists.llvm.org
Sun Jan 26 11:32:41 PST 2025
================
@@ -22011,138 +22010,188 @@ static SDValue tryCombineWhileLo(SDNode *N,
return SDValue(N, 0);
}
-SDValue tryLowerPartialReductionToDot(SDNode *N,
- const AArch64Subtarget *Subtarget,
- SelectionDAG &DAG) {
-
- assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
- getIntrinsicID(N) ==
- Intrinsic::experimental_vector_partial_reduce_add &&
- "Expected a partial reduction node");
-
- bool Scalable = N->getValueType(0).isScalableVector();
- if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
+// Try to absorb a MUL of two extends that feeds a partial reduction whose
+// multiplier operand is a splat of 1 (pattern below). Returns an empty SDValue
+// when the pattern does not match.
+SDValue tryCombinePartialReduceMLAMulOp(SDValue Op0, SDValue Op1,
+ SDValue Op2, SelectionDAG &DAG,
+ const SDLoc &DL) {
+ // Makes PARTIAL_REDUCE_MLA(Acc, MUL(EXT(MulOpLHS), EXT(MulOpRHS)), Splat (1))
+ // into PARTIAL_REDUCE_MLA(Acc, EXT(MulOpLHS), EXT(MulOpRHS))
+ if (Op1->getOpcode() != ISD::MUL)
return SDValue();
- if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
+
+ SDValue ExtMulOpLHS = Op1->getOperand(0);
+ SDValue ExtMulOpRHS = Op1->getOperand(1);
+ unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
+ unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
+ if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
+ !ISD::isExtOpcode(ExtMulOpRHSOpcode))
return SDValue();
- SDLoc DL(N);
+ // Both multiplicands must be extended from the same narrow type.
+ SDValue MulOpLHS = ExtMulOpLHS->getOperand(0);
+ SDValue MulOpRHS = ExtMulOpRHS->getOperand(0);
+ EVT MulOpLHSVT = MulOpLHS.getValueType();
+ if (MulOpLHSVT != MulOpRHS.getValueType())
+ return SDValue();
- SDValue Op2 = N->getOperand(2);
unsigned Op2Opcode = Op2->getOpcode();
- SDValue MulOpLHS, MulOpRHS;
- bool MulOpLHSIsSigned, MulOpRHSIsSigned;
- if (ISD::isExtOpcode(Op2Opcode)) {
- MulOpLHSIsSigned = MulOpRHSIsSigned = (Op2Opcode == ISD::SIGN_EXTEND);
- MulOpLHS = Op2->getOperand(0);
- MulOpRHS = DAG.getConstant(1, DL, MulOpLHS.getValueType());
- } else if (Op2Opcode == ISD::MUL) {
- SDValue ExtMulOpLHS = Op2->getOperand(0);
- SDValue ExtMulOpRHS = Op2->getOperand(1);
-
- unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
- unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
- if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
- !ISD::isExtOpcode(ExtMulOpRHSOpcode))
- return SDValue();
+ // Every lane must be 1: checking only operand 0 would accept a BUILD_VECTOR
+ // such as <1, 2, 3, 4> and silently drop the multiplication.
+ if ((Op2Opcode != ISD::SPLAT_VECTOR && Op2Opcode != ISD::BUILD_VECTOR) ||
+ !isOneOrOneSplat(Op2))
+ return SDValue();
- MulOpLHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
- MulOpRHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
+ return DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, DL, Op0.getValueType(), Op0,
+ ExtMulOpLHS, ExtMulOpRHS);
+}
- MulOpLHS = ExtMulOpLHS->getOperand(0);
- MulOpRHS = ExtMulOpRHS->getOperand(0);
+// Try to rewrite PARTIAL_REDUCE_MLA(Acc, EXT(A), EXT(B)) -- or with a splat of
+// 1 as the second multiplicand -- into a dot-product form: either a
+// PARTIAL_REDUCE_{S,U}MLA on the unextended operands, or an AArch64ISD
+// {S,U,US}DOT node (followed by an extend and ADD for i64 accumulators) for the
+// cases that legalisation would otherwise split. Returns an empty SDValue when
+// no dot-product lowering applies.
+SDValue tryCombineToDotProduct(SDValue Op0, SDValue ExtOp1, SDValue ExtOp2,
+ SelectionDAG &DAG,
+ const AArch64Subtarget *Subtarget, const SDLoc &DL) {
+ // Use the SDValue's own type; Op0->getValueType(0) would read result 0 of the
+ // defining node, which differs if Op0 refers to another result.
+ bool Scalable = Op0.getValueType().isScalableVector();
+ if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
+ return SDValue();
+ if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
+ return SDValue();
- if (MulOpLHS.getValueType() != MulOpRHS.getValueType())
+ unsigned ExtOp1Opcode = ExtOp1->getOpcode();
+ unsigned ExtOp2Opcode = ExtOp2->getOpcode();
+ SDValue Op1, Op2;
+ bool Op1IsSigned, Op2IsSigned;
+ if (!ISD::isExtOpcode(ExtOp1Opcode))
+ return SDValue();
+ Op1 = ExtOp1->getOperand(0);
+ EVT SrcVT = Op1.getValueType();
+
+ if ((ExtOp2Opcode == ISD::SPLAT_VECTOR ||
+ ExtOp2Opcode == ISD::BUILD_VECTOR) &&
+ isOneOrOneSplat(ExtOp2)) {
+ // Makes PARTIAL_REDUCE_MLA(Acc, Ext(Op1), Splat(1)) into
+ // PARTIAL_REDUCE_MLA(Acc, Op1, Splat(1))
+ Op1IsSigned = Op2IsSigned = (ExtOp1Opcode == ISD::SIGN_EXTEND);
+ // Truncating to SrcVT is only value-preserving because every lane is 1.
+ Op2 = DAG.getAnyExtOrTrunc(ExtOp2, DL, SrcVT);
+ } else if (ISD::isExtOpcode(ExtOp2Opcode)) {
+ // Makes PARTIAL_REDUCE_MLA(Acc, Ext(Op1), Ext(Op2)) into
+ // PARTIAL_REDUCE_MLA(Acc, Op1, Op2)
+ Op2 = ExtOp2->getOperand(0);
+ Op1IsSigned = ExtOp1Opcode == ISD::SIGN_EXTEND;
+ Op2IsSigned = ExtOp2Opcode == ISD::SIGN_EXTEND;
+ if (SrcVT != Op2.getValueType())
return SDValue();
- } else
+ } else {
return SDValue();
+ }
- SDValue Acc = N->getOperand(1);
- EVT ReducedVT = N->getValueType(0);
- EVT MulSrcVT = MulOpLHS.getValueType();
+ SDValue Acc = Op0;
+ EVT ReducedVT = Acc.getValueType();
// Dot products operate on chunks of four elements so there must be four times
// as many elements in the wide type
- if (!(ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
- !(ReducedVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
- !(ReducedVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
- !(ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
- !(ReducedVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
- !(ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
+ if (!(ReducedVT == MVT::nxv4i64 && SrcVT == MVT::nxv16i8) &&
+ !(ReducedVT == MVT::nxv4i32 && SrcVT == MVT::nxv16i8) &&
+ !(ReducedVT == MVT::nxv2i64 && SrcVT == MVT::nxv8i16) &&
+ !(ReducedVT == MVT::v4i64 && SrcVT == MVT::v16i8) &&
+ !(ReducedVT == MVT::v4i32 && SrcVT == MVT::v16i8) &&
+ !(ReducedVT == MVT::v2i32 && SrcVT == MVT::v8i8))
return SDValue();
// If the extensions are mixed, we should lower it to a usdot instead
- unsigned Opcode = 0;
- if (MulOpLHSIsSigned != MulOpRHSIsSigned) {
+ unsigned DotOpcode = Op1IsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
+ if (Op1IsSigned != Op2IsSigned) {
if (!Subtarget->hasMatMulInt8())
return SDValue();
- bool Scalable = N->getValueType(0).isScalableVT();
// There's no nxv2i64 version of usdot
if (Scalable && ReducedVT != MVT::nxv4i32 && ReducedVT != MVT::nxv4i64)
return SDValue();
- Opcode = AArch64ISD::USDOT;
- // USDOT expects the signed operand to be last
- if (!MulOpRHSIsSigned)
- std::swap(MulOpLHS, MulOpRHS);
- } else
- Opcode = MulOpLHSIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
+ // USDOT expects the signed operand to be last
+ if (!Op2IsSigned)
+ std::swap(Op1, Op2);
+ DotOpcode = AArch64ISD::USDOT;
+ // Lower usdot patterns here because legalisation would attempt to split it
+ // unless exts are removed. But, removing the exts would lose the
+ // information about whether each operand is signed.
+ if ((ReducedVT != MVT::nxv4i64 || SrcVT != MVT::nxv16i8) &&
+ (ReducedVT != MVT::v4i64 || SrcVT != MVT::v16i8))
+ return DAG.getNode(DotOpcode, DL, ReducedVT, Acc, Op1, Op2);
+ }
// Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
- // product followed by a zero / sign extension
- if ((ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
- (ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) {
+ // product followed by a zero / sign extension. Need to lower this here
+ // because legalisation would attempt to split it.
+ if ((ReducedVT == MVT::nxv4i64 && SrcVT == MVT::nxv16i8) ||
+ (ReducedVT == MVT::v4i64 && SrcVT == MVT::v16i8)) {
EVT ReducedVTI32 =
(ReducedVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
SDValue DotI32 =
- DAG.getNode(Opcode, DL, ReducedVTI32,
- DAG.getConstant(0, DL, ReducedVTI32), MulOpLHS, MulOpRHS);
+ DAG.getNode(DotOpcode, DL, ReducedVTI32,
+ DAG.getConstant(0, DL, ReducedVTI32), Op1, Op2);
SDValue Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedVT);
return DAG.getNode(ISD::ADD, DL, ReducedVT, Acc, Extended);
}
- return DAG.getNode(Opcode, DL, ReducedVT, Acc, MulOpLHS, MulOpRHS);
+ // Matching extends: hand the unextended operands back as a partial reduction
+ // whose signedness is taken from the (shared) extension kind.
+ unsigned NewOpcode =
+ Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+ return DAG.getNode(NewOpcode, DL, ReducedVT, Acc, Op1, Op2);
}
-SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
- const AArch64Subtarget *Subtarget,
- SelectionDAG &DAG) {
-
- assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
- getIntrinsicID(N) ==
- Intrinsic::experimental_vector_partial_reduce_add &&
- "Expected a partial reduction node");
-
+SDValue tryCombineToWideAdd(SDValue &Op0, SDValue &ExtOp1, SDValue &Op2,
+ SelectionDAG &DAG,
+ const AArch64Subtarget *Subtarget, SDLoc &DL) {
+ // Makes PARTIAL_REDUCE_MLA(Acc, Ext(Op1), Splat(1)) into
+ // PARTIAL_REDUCE_MLA(Acc, Op1, Splat(1))
if (!Subtarget->hasSVE2() && !Subtarget->isStreamingSVEAvailable())
return SDValue();
+ EVT AccVT = Op0->getValueType(0);
+ unsigned ExtOp1Opcode = ExtOp1->getOpcode();
+ if (!ISD::isExtOpcode(ExtOp1Opcode))
+ return SDValue();
+ SDValue Op1 = ExtOp1->getOperand(0);
+ EVT Op1VT = Op1.getValueType();
- SDLoc DL(N);
-
- if (!ISD::isExtOpcode(N->getOperand(2).getOpcode()))
+ unsigned Op2Opcode = Op2->getOpcode();
+ if (Op2Opcode != ISD::SPLAT_VECTOR || !isOneConstant(Op2->getOperand(0)))
return SDValue();
- SDValue Acc = N->getOperand(1);
- SDValue Ext = N->getOperand(2);
- EVT AccVT = Acc.getValueType();
- EVT ExtVT = Ext.getValueType();
- if (ExtVT.getVectorElementType() != AccVT.getVectorElementType())
+ Op2 = DAG.getAnyExtOrTrunc(Op2, DL, Op1VT);
+
+ if (!(Op1VT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
+ !(Op1VT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
+ !(Op1VT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
return SDValue();
- SDValue ExtOp = Ext->getOperand(0);
- EVT ExtOpVT = ExtOp.getValueType();
+ unsigned NewOpcode = ExtOp1Opcode == ISD::SIGN_EXTEND
+ ? ISD::PARTIAL_REDUCE_SMLA
+ : ISD::PARTIAL_REDUCE_UMLA;
- if (!(ExtOpVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
- !(ExtOpVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
- !(ExtOpVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
- return SDValue();
+ return DAG.getNode(NewOpcode, DL, AccVT, Op0, Op1, Op2);
+}
- bool ExtOpIsSigned = Ext.getOpcode() == ISD::SIGN_EXTEND;
- unsigned BottomOpcode =
- ExtOpIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
- unsigned TopOpcode = ExtOpIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
- SDValue BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, ExtOp);
- return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, ExtOp);
+SDValue performPartialReduceMLACombine(SDNode *N, SelectionDAG &DAG,
+ const AArch64Subtarget *Subtarget) {
+ SDLoc DL(N);
+ SDValue Op0 = N->getOperand(0);
+ SDValue Op1 = N->getOperand(1);
+ SDValue Op2 = N->getOperand(2);
+ EVT Op0ElemVT = Op0.getValueType().getVectorElementType();
+ EVT Op1ElemVT = Op1.getValueType().getVectorElementType();
+
+ // If the exts have already been removed or it has already been lowered to an
+ // usdot instruction, then the element types will not be equal
+ if (Op0ElemVT != Op1ElemVT || Op1.getOpcode() == AArch64ISD::USDOT)
+ return SDValue(N, 0);
+
+ if (auto MLA = tryCombinePartialReduceMLAMulOp(Op0, Op1, Op2, DAG, DL)) {
+ Op0 = MLA->getOperand(0);
+ Op1 = MLA->getOperand(1);
+ Op2 = MLA->getOperand(2);
+ }
+ if (auto Dot = tryCombineToDotProduct(Op0, Op1, Op2, DAG, Subtarget, DL))
+ return Dot;
+ if (auto WideAdd = tryCombineToWideAdd(Op0, Op1, Op2, DAG, Subtarget, DL))
+ return WideAdd;
+ // N->getOperand needs calling again because the Op variables may have been
+ // changed by the functions above
+ return DAG.expandPartialReduceMLA(DL, N->getOperand(0), N->getOperand(1),
----------------
sdesmalen-arm wrote:
The default should not be to expand the reduction here. This is a target-specific DAGCombine that tries to optimise the DAG for AArch64-specific use cases. Expansion should only happen if the optimised DAG cannot be lowered or legalised. Instead, the combine should try to fold the extends into the node, e.g.:
```
v4i32 PARTIAL_REDUCE_UMLA(v4i32 Acc, v8i32 SEXT(v8i16 X), v8i32 SEXT(v8i16 Y))
->
v4i32 PARTIAL_REDUCE_SMLA(v4i32 Acc, v8i16 X, v8i16 Y)
```
In the case where the extends cannot be recognised, e.g.
```
v4i32 PARTIAL_REDUCE_UMLA(v4i32 Acc, v8i32 X, v8i32 Y)
```
Then this would require type legalisation (splitting each v8i32 operand into 2 x v4i32), which for now could fall back to `expandPartialReduceMLA`.
The cases that can't be represented with these new nodes are the USDOT instructions (operands with mixed sign and zero extensions). Those you could lower in this function to a custom AArch64ISD node.
https://github.com/llvm/llvm-project/pull/117185
More information about the llvm-commits
mailing list