[llvm] [AArch64][SVE] Add partial reduction SDNodes (PR #117185)

James Chesterman via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 23 06:34:09 PST 2025


================
@@ -22011,34 +22010,25 @@ static SDValue tryCombineWhileLo(SDNode *N,
   return SDValue(N, 0);
 }
 
-SDValue tryLowerPartialReductionToDot(SDNode *N,
-                                      const AArch64Subtarget *Subtarget,
-                                      SelectionDAG &DAG) {
-
-  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
-         getIntrinsicID(N) ==
-             Intrinsic::experimental_vector_partial_reduce_add &&
-         "Expected a partial reduction node");
-
-  bool Scalable = N->getValueType(0).isScalableVector();
+SDValue tryCombineToDotProduct(SDValue &Op0, SDValue &Op1, SDValue &Op2,
+                               SelectionDAG &DAG,
+                               const AArch64Subtarget *Subtarget, SDLoc &DL) {
+  bool Scalable = Op0->getValueType(0).isScalableVector();
   if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
     return SDValue();
   if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
     return SDValue();
 
-  SDLoc DL(N);
-
-  SDValue Op2 = N->getOperand(2);
-  unsigned Op2Opcode = Op2->getOpcode();
+  unsigned Op1Opcode = Op1->getOpcode();
   SDValue MulOpLHS, MulOpRHS;
   bool MulOpLHSIsSigned, MulOpRHSIsSigned;
-  if (ISD::isExtOpcode(Op2Opcode)) {
-    MulOpLHSIsSigned = MulOpRHSIsSigned = (Op2Opcode == ISD::SIGN_EXTEND);
-    MulOpLHS = Op2->getOperand(0);
-    MulOpRHS = DAG.getConstant(1, DL, MulOpLHS.getValueType());
-  } else if (Op2Opcode == ISD::MUL) {
-    SDValue ExtMulOpLHS = Op2->getOperand(0);
-    SDValue ExtMulOpRHS = Op2->getOperand(1);
+  if (ISD::isExtOpcode(Op1Opcode)) {
+    MulOpLHSIsSigned = MulOpRHSIsSigned = (Op1Opcode == ISD::SIGN_EXTEND);
+    MulOpLHS = Op1->getOperand(0);
+    MulOpRHS = DAG.getAnyExtOrTrunc(Op2, DL, MulOpLHS.getValueType());
+  } else if (Op1Opcode == ISD::MUL) {
+    SDValue ExtMulOpLHS = Op1->getOperand(0);
+    SDValue ExtMulOpRHS = Op1->getOperand(1);
----------------
JamesChesterman wrote:

Done. Now there is a separate function for lowering:
`PARTIAL_REDUCE_MLA (Acc, MUL(EXT(MulOpLHS), EXT(MulOpRHS)), SPLAT (1))`
To:
`PARTIAL_REDUCE_MLA(Acc, EXT(MulOpLHS), EXT(MulOpRHS))`
The function `tryCombineToDotProduct` can then handle this pattern (removing the extends), as well as the pattern `PARTIAL_REDUCE_MLA(Acc, EXT(Op), SPLAT (1))`.
I've added comments in the relevant places detailing what happens.

https://github.com/llvm/llvm-project/pull/117185


More information about the llvm-commits mailing list