[llvm] [AArch64][SVE] Add partial reduction SDNodes (PR #117185)
Benjamin Maxwell via llvm-commits
llvm-commits at lists.llvm.org
Tue Jan 7 06:13:03 PST 2025
================
@@ -29164,6 +29188,32 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
return Scatter;
}
+SDValue
+AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
+ SelectionDAG &DAG) const {
+ SDLoc DL(Op);
+ SDValue Acc = Op.getOperand(0);
+ SDValue Input1 = Op.getOperand(1);
+ SDValue Input2 = Op.getOperand(2);
+
+ EVT AccVT = Acc.getValueType();
+ EVT InputVT = Input1.getValueType();
+
+ unsigned Opcode = Op.getOpcode();
+
+ if (AccVT.getVectorElementCount() * 4 == InputVT.getVectorElementCount()) {
+ unsigned DotOpcode = Opcode == ISD::PARTIAL_REDUCE_SMLA ? AArch64ISD::SDOT
+ : AArch64ISD::UDOT;
+ return DAG.getNode(DotOpcode, DL, AccVT, Acc, Input1, Input2);
+ }
+ bool InputIsSigned = Opcode == ISD::PARTIAL_REDUCE_SMLA;
+ unsigned BottomOpcode =
+ InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
+ unsigned TopOpcode = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
+ auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input1);
+ return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input1);
----------------
MacDue wrote:
I don't think this should assume `Input2` is a vector of all ones if `AccVT.getVectorElementCount() * 4 != InputVT.getVectorElementCount()`.
https://github.com/llvm/llvm-project/pull/117185
More information about the llvm-commits
mailing list