[llvm] [AArch64][SVE] Add partial reduction SDNodes (PR #117185)

Benjamin Maxwell via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 20 08:37:02 PST 2025


================
@@ -22011,34 +22010,25 @@ static SDValue tryCombineWhileLo(SDNode *N,
   return SDValue(N, 0);
 }
 
-SDValue tryLowerPartialReductionToDot(SDNode *N,
-                                      const AArch64Subtarget *Subtarget,
-                                      SelectionDAG &DAG) {
-
-  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
-         getIntrinsicID(N) ==
-             Intrinsic::experimental_vector_partial_reduce_add &&
-         "Expected a partial reduction node");
-
-  bool Scalable = N->getValueType(0).isScalableVector();
+SDValue tryCombineToDotProduct(SDValue &Op0, SDValue &Op1, SDValue &Op2,
+                               SelectionDAG &DAG,
+                               const AArch64Subtarget *Subtarget, SDLoc &DL) {
+  bool Scalable = Op0->getValueType(0).isScalableVector();
   if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
     return SDValue();
   if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
     return SDValue();
 
-  SDLoc DL(N);
-
-  SDValue Op2 = N->getOperand(2);
-  unsigned Op2Opcode = Op2->getOpcode();
+  unsigned Op1Opcode = Op1->getOpcode();
   SDValue MulOpLHS, MulOpRHS;
   bool MulOpLHSIsSigned, MulOpRHSIsSigned;
-  if (ISD::isExtOpcode(Op2Opcode)) {
-    MulOpLHSIsSigned = MulOpRHSIsSigned = (Op2Opcode == ISD::SIGN_EXTEND);
-    MulOpLHS = Op2->getOperand(0);
-    MulOpRHS = DAG.getConstant(1, DL, MulOpLHS.getValueType());
-  } else if (Op2Opcode == ISD::MUL) {
-    SDValue ExtMulOpLHS = Op2->getOperand(0);
-    SDValue ExtMulOpRHS = Op2->getOperand(1);
+  if (ISD::isExtOpcode(Op1Opcode)) {
+    MulOpLHSIsSigned = MulOpRHSIsSigned = (Op1Opcode == ISD::SIGN_EXTEND);
+    MulOpLHS = Op1->getOperand(0);
+    MulOpRHS = DAG.getAnyExtOrTrunc(Op2, DL, MulOpLHS.getValueType());
+  } else if (Op1Opcode == ISD::MUL) {
+    SDValue ExtMulOpLHS = Op1->getOperand(0);
+    SDValue ExtMulOpRHS = Op1->getOperand(1);
----------------
MacDue wrote:

I'm a little confused by what this fold is doing now. `PARTIAL_REDUCE_MLA` computes `acc += partial_reduce(op1 * op2)`, but this combine also checks whether one of its operands is itself a `MUL` node. I think this would be simpler to follow if there were a separate fold that took:


```
mul = op1 * op2
ret = PARTIAL_REDUCE_MLA(mul, splat(1))
```

And folded it to:

```
ret = PARTIAL_REDUCE_MLA(op1, op2)
```

Then the dot product lowering can be simpler (since it does not need to worry about the `MUL` node). 

https://github.com/llvm/llvm-project/pull/117185


More information about the llvm-commits mailing list