[llvm] [AArch64][SVE] Add partial reduction SDNodes (PR #117185)

Benjamin Maxwell via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 7 06:13:03 PST 2025


================
@@ -29164,6 +29188,32 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
   return Scatter;
 }
 
+SDValue
+AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
+                                               SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  SDValue Acc = Op.getOperand(0);
+  SDValue Input1 = Op.getOperand(1);
+  SDValue Input2 = Op.getOperand(2);
+
+  EVT AccVT = Acc.getValueType();
+  EVT InputVT = Input1.getValueType();
+
+  unsigned Opcode = Op.getOpcode();
+
+  if (AccVT.getVectorElementCount() * 4 == InputVT.getVectorElementCount()) {
+    unsigned DotOpcode = Opcode == ISD::PARTIAL_REDUCE_SMLA ? AArch64ISD::SDOT
+                                                            : AArch64ISD::UDOT;
+    return DAG.getNode(DotOpcode, DL, AccVT, Acc, Input1, Input2);
+  }
+  bool InputIsSigned = Opcode == ISD::PARTIAL_REDUCE_SMLA;
+  unsigned BottomOpcode =
+      InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
+  unsigned TopOpcode = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
+  auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input1);
+  return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input1);
----------------
MacDue wrote:

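I don't think this should assume `Input2` is a vector of all ones if `AccVT.getVectorElementCount() * 4 != InputVT.getVectorElementCount()`. In that fallback path the `[SU]ADDWB`/`[SU]ADDWT` nodes are built from `Acc` and `Input1` only, so `Input2` is never read and the lowering is only correct when `Input2` happens to be a splat of 1.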

https://github.com/llvm/llvm-project/pull/117185


More information about the llvm-commits mailing list