[llvm] [AArch64][SVE] Add partial reduction SDNodes (PR #117185)

James Chesterman via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 20 07:19:21 PST 2025


================
@@ -29164,6 +29188,32 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
   return Scatter;
 }
 
+SDValue
+AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
+                                               SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  SDValue Acc = Op.getOperand(0);
+  SDValue Input1 = Op.getOperand(1);
+  SDValue Input2 = Op.getOperand(2);
+
+  EVT AccVT = Acc.getValueType();
+  EVT InputVT = Input1.getValueType();
+
+  unsigned Opcode = Op.getOpcode();
+
+  if (AccVT.getVectorElementCount() * 4 == InputVT.getVectorElementCount()) {
+    unsigned DotOpcode = Opcode == ISD::PARTIAL_REDUCE_SMLA ? AArch64ISD::SDOT
+                                                            : AArch64ISD::UDOT;
+    return DAG.getNode(DotOpcode, DL, AccVT, Acc, Input1, Input2);
+  }
+  bool InputIsSigned = Opcode == ISD::PARTIAL_REDUCE_SMLA;
+  unsigned BottomOpcode =
+      InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
+  unsigned TopOpcode = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
+  auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input1);
+  return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input1);
----------------
JamesChesterman wrote:

I've changed this so that a `MUL` is now created here when `Input2` is not a splat vector of constant 1s. The DAG combine still creates a `MUL` as well, because a `MUL` created at that stage can be removed if an operand is a vector of constant 1s, or turned into a shift if an operand is a power of 2. Neither fold would happen if the `MUL` were only created in `LowerPARTIAL_REDUCE_MLA`. In the DAG combine, `Input2` is then set to a vector of constant 1s, so the `MUL` is not repeated during lowering.
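
To illustrate the intent, here is a minimal standalone scalar model of that behaviour. It is only a sketch and not the code in the PR: `isSplatOfOnes`, `mulLanes`, `addBottomThenTop` and `lowerPartialReduceMla` are hypothetical names, plain `std::vector`s stand in for the vector types, the 2:1 lane ratio is assumed, and overflow is not modelled.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Wide = std::vector<int64_t>;   // stands in for the accumulator type
using Narrow = std::vector<int32_t>; // stands in for the input type

// Hypothetical helper: true when every lane equals 1.
bool isSplatOfOnes(const Narrow &V) {
  return std::all_of(V.begin(), V.end(), [](int32_t X) { return X == 1; });
}

// Hypothetical helper: lane-wise multiply in the narrow type (models the MUL).
Narrow mulLanes(const Narrow &A, const Narrow &B) {
  assert(A.size() == B.size() && "operands must have the same lane count");
  Narrow P(A.size());
  for (std::size_t I = 0; I < A.size(); ++I)
    P[I] = A[I] * B[I];
  return P;
}

// Models the bottom/top widening-add pair: accumulator lane I receives the
// even ("bottom") input lane 2*I and then the odd ("top") input lane 2*I+1.
Wide addBottomThenTop(Wide Acc, const Narrow &In) {
  assert(In.size() == 2 * Acc.size() && "input must have twice the lanes");
  for (std::size_t I = 0; I < Acc.size(); ++I) {
    Acc[I] += In[2 * I];     // bottom step
    Acc[I] += In[2 * I + 1]; // top step
  }
  return Acc;
}

// Models the lowering: the multiply is skipped when the second input is a
// splat of 1s, which is what the combine leaves behind after creating its
// own MUL.
Wide lowerPartialReduceMla(const Wide &Acc, const Narrow &In1,
                           const Narrow &In2) {
  if (isSplatOfOnes(In2))
    return addBottomThenTop(Acc, In1);
  return addBottomThenTop(Acc, mulLanes(In1, In2));
}
```

In this model, calling `lowerPartialReduceMla` with the original inputs gives the same result as first multiplying with `mulLanes` (the combine's `MUL`) and then passing a splat of 1s as the second input, so the multiply happens exactly once on either path.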

https://github.com/llvm/llvm-project/pull/117185

