[llvm] [AArch64][SVE] Add partial reduction SDNodes (PR #117185)
James Chesterman via llvm-commits
llvm-commits at lists.llvm.org
Mon Jan 20 07:19:21 PST 2025
================
@@ -29164,6 +29188,32 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
return Scatter;
}
+SDValue
+AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
+ SelectionDAG &DAG) const {
+ SDLoc DL(Op);
+ SDValue Acc = Op.getOperand(0);
+ SDValue Input1 = Op.getOperand(1);
+ SDValue Input2 = Op.getOperand(2);
+
+ EVT AccVT = Acc.getValueType();
+ EVT InputVT = Input1.getValueType();
+
+ unsigned Opcode = Op.getOpcode();
+
+ if (AccVT.getVectorElementCount() * 4 == InputVT.getVectorElementCount()) {
+ unsigned DotOpcode = Opcode == ISD::PARTIAL_REDUCE_SMLA ? AArch64ISD::SDOT
+ : AArch64ISD::UDOT;
+ return DAG.getNode(DotOpcode, DL, AccVT, Acc, Input1, Input2);
+ }
+ bool InputIsSigned = Opcode == ISD::PARTIAL_REDUCE_SMLA;
+ unsigned BottomOpcode =
+ InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
+ unsigned TopOpcode = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
+ auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input1);
+ return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input1);
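A rough scalar model of what the two lowering paths in the hunk above compute. This is a sketch under several assumptions. It covers only the signed case; the unsigned case is the same with zero-extension. It assumes an i32 accumulator with i8 inputs for the 4:1 `SDOT` path and i16 inputs for the 2:1 `SADDWB`/`SADDWT` path. It uses fixed-size `std::vector`s in place of scalable vectors and ignores overflow. The function names `sdotModel`, `saddwbtModel` and `laneSum` are hypothetical, not LLVM APIs. Note that the wide-add path reads only `Input1`, so it is equivalent to the multiply-accumulate only when `Input2` is a splat of 1s.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical scalar model of the 4:1 path (SDOT): accumulator lane i
// absorbs the four sign-extended products a[4i+j] * b[4i+j], j = 0..3.
std::vector<int32_t> sdotModel(std::vector<int32_t> acc,
                               const std::vector<int8_t> &a,
                               const std::vector<int8_t> &b) {
  assert(a.size() == 4 * acc.size() && b.size() == a.size());
  for (std::size_t i = 0; i < acc.size(); ++i)
    for (std::size_t j = 0; j < 4; ++j)
      acc[i] += int32_t(a[4 * i + j]) * int32_t(b[4 * i + j]);
  return acc;
}

// Hypothetical scalar model of the 2:1 path (SADDWB, then SADDWT). Only the
// first input is read, so it equals the MLA only when Input2 is a splat of 1s.
std::vector<int32_t> saddwbtModel(std::vector<int32_t> acc,
                                  const std::vector<int16_t> &in) {
  assert(in.size() == 2 * acc.size());
  for (std::size_t i = 0; i < acc.size(); ++i)
    acc[i] += int32_t(in[2 * i]);     // SADDWB: even ("bottom") elements
  for (std::size_t i = 0; i < acc.size(); ++i)
    acc[i] += int32_t(in[2 * i + 1]); // SADDWT: odd ("top") elements
  return acc;
}

// Sum over all result lanes. This is the property both paths must preserve,
// assuming (as for llvm.experimental.vector.partial.reduce.add) that which
// input elements feed which accumulator lane is left unspecified.
int64_t laneSum(const std::vector<int32_t> &v) {
  return std::accumulate(v.begin(), v.end(), int64_t{0});
}
```

For example, with `acc = {10, 20}`, `a = {1, 2, 3, 4, -1, -2, -3, -4}` and `b = {1, 1, 1, 1, 2, 2, 2, 2}`, `sdotModel` returns `{20, 0}`. Its lane sum, 20, equals `sum(acc) + sum(a[k] * b[k])` = 30 - 10.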
----------------
JamesChesterman wrote:
I've now changed this so that a `MUL` node is created here if `Input2` is not a splat vector of constant 1s. A `MUL` is still created in the DAG combine as well, because there it can be removed if an operand is a splat of constant 1s, or turned into a shift if an operand is a splat of a power of 2. Neither would happen if the `MUL` were only created in `LowerPARTIAL_REDUCE_MLA`. The DAG combine then sets `Input2` to a splat of constant 1s, so the `MUL` is not created a second time during lowering.
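A small sketch of the two folds the comment relies on for `mul X, splat(C)`. When `C == 1` the multiply disappears. When `C` is a power of 2 it becomes a left shift by `log2(C)`. Otherwise the `MUL` stays. These correspond to the generic `MUL` simplifications that LLVM's `DAGCombiner` performs. The code below is only a per-lane model of that decision. `MulFold`, `foldMulBySplat` and `applyFold` are hypothetical names, not LLVM APIs. Unsigned arithmetic is used so that the shift is well defined.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical classification of what a combine can do with `mul X, splat(C)`.
enum class MulFold { RemoveMul, ShiftLeft, KeepMul };

struct FoldResult {
  MulFold kind;
  unsigned shiftAmount; // only meaningful when kind == ShiftLeft
};

FoldResult foldMulBySplat(uint64_t c) {
  if (c == 1)
    return {MulFold::RemoveMul, 0}; // mul X, 1 -> X
  if (c != 0 && (c & (c - 1)) == 0) {
    unsigned s = 0;
    while ((c >> s) != 1)
      ++s;                          // s = log2(c)
    return {MulFold::ShiftLeft, s}; // mul X, 2^s -> shl X, s
  }
  return {MulFold::KeepMul, 0};
}

// Applies the chosen fold to a single lane value; the result always equals
// x * c (modulo 2^64).
uint64_t applyFold(uint64_t x, uint64_t c) {
  FoldResult r = foldMulBySplat(c);
  switch (r.kind) {
  case MulFold::RemoveMul:
    return x;
  case MulFold::ShiftLeft:
    return x << r.shiftAmount;
  case MulFold::KeepMul:
    return x * c;
  }
  return x * c; // not reached; silences missing-return warnings
}
```

For instance, `applyFold(5, 8)` takes the shift path (`shiftAmount == 3`) and returns 40. If the `MUL` were only materialised inside the lowering hook, this is the kind of simplification the comment says would be missed.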
https://github.com/llvm/llvm-project/pull/117185