[llvm] [AArch64][SVE] Add partial reduction SDNodes (PR #117185)

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Sun Jan 26 11:32:39 PST 2025


================
@@ -22011,138 +22010,188 @@ static SDValue tryCombineWhileLo(SDNode *N,
   return SDValue(N, 0);
 }
 
-SDValue tryLowerPartialReductionToDot(SDNode *N,
-                                      const AArch64Subtarget *Subtarget,
-                                      SelectionDAG &DAG) {
-
-  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
-         getIntrinsicID(N) ==
-             Intrinsic::experimental_vector_partial_reduce_add &&
-         "Expected a partial reduction node");
-
-  bool Scalable = N->getValueType(0).isScalableVector();
-  if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
+SDValue tryCombinePartialReduceMLAMulOp(SDValue &Op0, SDValue &Op1,
+                                        SDValue &Op2, SelectionDAG &DAG,
+                                        SDLoc &DL) {
+  // Makes PARTIAL_REDUCE_MLA(Acc, MUL(EXT(MulOpLHS), EXT(MulOpRHS)), Splat (1))
+  // into PARTIAL_REDUCE_MLA(Acc, EXT(MulOpLHS), EXT(MulOpRHS))
+  if (Op1->getOpcode() != ISD::MUL)
     return SDValue();
-  if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
+
+  SDValue ExtMulOpLHS = Op1->getOperand(0);
+  SDValue ExtMulOpRHS = Op1->getOperand(1);
+  unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
+  unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
+  if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
+      !ISD::isExtOpcode(ExtMulOpRHSOpcode))
     return SDValue();
 
-  SDLoc DL(N);
+  SDValue MulOpLHS = ExtMulOpLHS->getOperand(0);
+  SDValue MulOpRHS = ExtMulOpRHS->getOperand(0);
+  EVT MulOpLHSVT = MulOpLHS.getValueType();
+  if (MulOpLHSVT != MulOpRHS.getValueType())
+    return SDValue();
 
-  SDValue Op2 = N->getOperand(2);
   unsigned Op2Opcode = Op2->getOpcode();
-  SDValue MulOpLHS, MulOpRHS;
-  bool MulOpLHSIsSigned, MulOpRHSIsSigned;
-  if (ISD::isExtOpcode(Op2Opcode)) {
-    MulOpLHSIsSigned = MulOpRHSIsSigned = (Op2Opcode == ISD::SIGN_EXTEND);
-    MulOpLHS = Op2->getOperand(0);
-    MulOpRHS = DAG.getConstant(1, DL, MulOpLHS.getValueType());
-  } else if (Op2Opcode == ISD::MUL) {
-    SDValue ExtMulOpLHS = Op2->getOperand(0);
-    SDValue ExtMulOpRHS = Op2->getOperand(1);
-
-    unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
-    unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
-    if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
-        !ISD::isExtOpcode(ExtMulOpRHSOpcode))
-      return SDValue();
+  if ((Op2Opcode != ISD::SPLAT_VECTOR && Op2Opcode != ISD::BUILD_VECTOR) ||
+      !isOneConstant(Op2->getOperand(0)))
----------------
sdesmalen-arm wrote:

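Please use `ISD::isConstantSplatVector` instead of checking the opcode and inspecting only operand 0 by hand (for a `BUILD_VECTOR`, the first element being one does not imply that every element is one), and then check that the resulting APInt is one using `isOne()`.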

https://github.com/llvm/llvm-project/pull/117185


More information about the llvm-commits mailing list