[llvm] [AArch64][SVE] Add partial reduction SDNodes (PR #117185)

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Sun Jan 26 11:32:39 PST 2025


================
@@ -22011,138 +22010,188 @@ static SDValue tryCombineWhileLo(SDNode *N,
   return SDValue(N, 0);
 }
 
-SDValue tryLowerPartialReductionToDot(SDNode *N,
-                                      const AArch64Subtarget *Subtarget,
-                                      SelectionDAG &DAG) {
-
-  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
-         getIntrinsicID(N) ==
-             Intrinsic::experimental_vector_partial_reduce_add &&
-         "Expected a partial reduction node");
-
-  bool Scalable = N->getValueType(0).isScalableVector();
-  if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
+SDValue tryCombinePartialReduceMLAMulOp(SDValue &Op0, SDValue &Op1,
+                                        SDValue &Op2, SelectionDAG &DAG,
+                                        SDLoc &DL) {
+  // Makes PARTIAL_REDUCE_MLA(Acc, MUL(EXT(MulOpLHS), EXT(MulOpRHS)), Splat (1))
+  // into PARTIAL_REDUCE_MLA(Acc, EXT(MulOpLHS), EXT(MulOpRHS))
+  if (Op1->getOpcode() != ISD::MUL)
     return SDValue();
-  if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
+
+  SDValue ExtMulOpLHS = Op1->getOperand(0);
+  SDValue ExtMulOpRHS = Op1->getOperand(1);
+  unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
+  unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
+  if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
+      !ISD::isExtOpcode(ExtMulOpRHSOpcode))
     return SDValue();
 
-  SDLoc DL(N);
+  SDValue MulOpLHS = ExtMulOpLHS->getOperand(0);
+  SDValue MulOpRHS = ExtMulOpRHS->getOperand(0);
+  EVT MulOpLHSVT = MulOpLHS.getValueType();
+  if (MulOpLHSVT != MulOpRHS.getValueType())
+    return SDValue();
 
-  SDValue Op2 = N->getOperand(2);
   unsigned Op2Opcode = Op2->getOpcode();
-  SDValue MulOpLHS, MulOpRHS;
-  bool MulOpLHSIsSigned, MulOpRHSIsSigned;
-  if (ISD::isExtOpcode(Op2Opcode)) {
-    MulOpLHSIsSigned = MulOpRHSIsSigned = (Op2Opcode == ISD::SIGN_EXTEND);
-    MulOpLHS = Op2->getOperand(0);
-    MulOpRHS = DAG.getConstant(1, DL, MulOpLHS.getValueType());
-  } else if (Op2Opcode == ISD::MUL) {
-    SDValue ExtMulOpLHS = Op2->getOperand(0);
-    SDValue ExtMulOpRHS = Op2->getOperand(1);
-
-    unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
-    unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
-    if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
-        !ISD::isExtOpcode(ExtMulOpRHSOpcode))
-      return SDValue();
+  if ((Op2Opcode != ISD::SPLAT_VECTOR && Op2Opcode != ISD::BUILD_VECTOR) ||
+      !isOneConstant(Op2->getOperand(0)))
+    return SDValue();
 
-    MulOpLHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
-    MulOpRHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
+  return DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, DL, Op0->getValueType(0), Op0,
+                     ExtMulOpLHS, ExtMulOpRHS);
+}
 
-    MulOpLHS = ExtMulOpLHS->getOperand(0);
-    MulOpRHS = ExtMulOpRHS->getOperand(0);
+SDValue tryCombineToDotProduct(SDValue &Op0, SDValue &ExtOp1, SDValue &ExtOp2,
+                               SelectionDAG &DAG,
+                               const AArch64Subtarget *Subtarget, SDLoc &DL) {
+  bool Scalable = Op0->getValueType(0).isScalableVector();
+  if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
+    return SDValue();
+  if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
+    return SDValue();
 
-    if (MulOpLHS.getValueType() != MulOpRHS.getValueType())
+  unsigned ExtOp1Opcode = ExtOp1->getOpcode();
+  unsigned ExtOp2Opcode = ExtOp2->getOpcode();
+  SDValue Op1, Op2;
+  bool Op1IsSigned, Op2IsSigned;
+  if (!ISD::isExtOpcode(ExtOp1Opcode))
+    return SDValue();
+  Op1 = ExtOp1->getOperand(0);
+  EVT SrcVT = Op1.getValueType();
+
+  if ((ExtOp2Opcode == ISD::SPLAT_VECTOR ||
+       ExtOp2Opcode == ISD::BUILD_VECTOR) &&
+      isOneConstant(ExtOp2.getOperand(0))) {
+    // Makes PARTIAL_REDUCE_MLA(Acc, Ext(Op1), Splat(1)) into
+    // PARTIAL_REDUCE_MLA(Acc, Op1, Splat(1))
+    Op1IsSigned = Op2IsSigned = (ExtOp1Opcode == ISD::SIGN_EXTEND);
+    // Can only do this because it's a splat vector of constant 1
+    Op2 = DAG.getAnyExtOrTrunc(ExtOp2, DL, SrcVT);
+  } else if (ISD::isExtOpcode(ExtOp2Opcode)) {
+    // Makes PARTIAL_REDUCE_MLA(Acc, Ext(Op1), Ext(Op2)) into
+    // PARTIAL_REDUCE_MLA(Acc, Op1, Op2)
+    Op2 = ExtOp2->getOperand(0);
+    Op1IsSigned = ExtOp1Opcode == ISD::SIGN_EXTEND;
+    Op2IsSigned = ExtOp2Opcode == ISD::SIGN_EXTEND;
+    if (SrcVT != Op2.getValueType())
       return SDValue();
-  } else
+  } else {
     return SDValue();
+  }
 
-  SDValue Acc = N->getOperand(1);
-  EVT ReducedVT = N->getValueType(0);
-  EVT MulSrcVT = MulOpLHS.getValueType();
+  SDValue Acc = Op0;
+  EVT ReducedVT = Acc->getValueType(0);
 
   // Dot products operate on chunks of four elements so there must be four times
   // as many elements in the wide type
-  if (!(ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
-      !(ReducedVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
-      !(ReducedVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
-      !(ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
-      !(ReducedVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
-      !(ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
+  if (!(ReducedVT == MVT::nxv4i64 && SrcVT == MVT::nxv16i8) &&
+      !(ReducedVT == MVT::nxv4i32 && SrcVT == MVT::nxv16i8) &&
+      !(ReducedVT == MVT::nxv2i64 && SrcVT == MVT::nxv8i16) &&
+      !(ReducedVT == MVT::v4i64 && SrcVT == MVT::v16i8) &&
+      !(ReducedVT == MVT::v4i32 && SrcVT == MVT::v16i8) &&
+      !(ReducedVT == MVT::v2i32 && SrcVT == MVT::v8i8))
     return SDValue();
 
   // If the extensions are mixed, we should lower it to a usdot instead
-  unsigned Opcode = 0;
-  if (MulOpLHSIsSigned != MulOpRHSIsSigned) {
+  unsigned DotOpcode = Op1IsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
+  if (Op1IsSigned != Op2IsSigned) {
     if (!Subtarget->hasMatMulInt8())
       return SDValue();
 
-    bool Scalable = N->getValueType(0).isScalableVT();
+    bool Scalable = ReducedVT.isScalableVT();
     // There's no nxv2i64 version of usdot
     if (Scalable && ReducedVT != MVT::nxv4i32 && ReducedVT != MVT::nxv4i64)
       return SDValue();
 
-    Opcode = AArch64ISD::USDOT;
-    // USDOT expects the signed operand to be last
-    if (!MulOpRHSIsSigned)
-      std::swap(MulOpLHS, MulOpRHS);
-  } else
-    Opcode = MulOpLHSIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
+    if (!Op2IsSigned)
+      std::swap(Op1, Op2);
+    DotOpcode = AArch64ISD::USDOT;
+    // Lower usdot patterns here because legalisation would attempt to split it
+    // unless exts are removed. But, removing the exts would lose the
+    // information about whether each operand is signed.
+    if ((ReducedVT != MVT::nxv4i64 || SrcVT != MVT::nxv16i8) &&
+        (ReducedVT != MVT::v4i64 || SrcVT != MVT::v16i8))
+      return DAG.getNode(DotOpcode, DL, ReducedVT, Acc, Op1, Op2);
+  }
 
   // Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
-  // product followed by a zero / sign extension
-  if ((ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
-      (ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) {
+  // product followed by a zero / sign extension. Need to lower this here
+  // because legalisation would attempt to split it.
+  if ((ReducedVT == MVT::nxv4i64 && SrcVT == MVT::nxv16i8) ||
+      (ReducedVT == MVT::v4i64 && SrcVT == MVT::v16i8)) {
     EVT ReducedVTI32 =
         (ReducedVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
 
     SDValue DotI32 =
-        DAG.getNode(Opcode, DL, ReducedVTI32,
-                    DAG.getConstant(0, DL, ReducedVTI32), MulOpLHS, MulOpRHS);
+        DAG.getNode(DotOpcode, DL, ReducedVTI32,
+                    DAG.getConstant(0, DL, ReducedVTI32), Op1, Op2);
     SDValue Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedVT);
     return DAG.getNode(ISD::ADD, DL, ReducedVT, Acc, Extended);
   }
 
-  return DAG.getNode(Opcode, DL, ReducedVT, Acc, MulOpLHS, MulOpRHS);
+  unsigned NewOpcode =
+      Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+  return DAG.getNode(NewOpcode, DL, ReducedVT, Acc, Op1, Op2);
 }
 
-SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
-                                          const AArch64Subtarget *Subtarget,
-                                          SelectionDAG &DAG) {
-
-  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
-         getIntrinsicID(N) ==
-             Intrinsic::experimental_vector_partial_reduce_add &&
-         "Expected a partial reduction node");
-
+SDValue tryCombineToWideAdd(SDValue &Op0, SDValue &ExtOp1, SDValue &Op2,
+                            SelectionDAG &DAG,
+                            const AArch64Subtarget *Subtarget, SDLoc &DL) {
+  // Makes PARTIAL_REDUCE_MLA(Acc, Ext(Op1), Splat(1)) into
+  // PARTIAL_REDUCE_MLA(Acc, Op1, Splat(1))
   if (!Subtarget->hasSVE2() && !Subtarget->isStreamingSVEAvailable())
     return SDValue();
+  EVT AccVT = Op0->getValueType(0);
+  unsigned ExtOp1Opcode = ExtOp1->getOpcode();
+  if (!ISD::isExtOpcode(ExtOp1Opcode))
+    return SDValue();
+  SDValue Op1 = ExtOp1->getOperand(0);
+  EVT Op1VT = Op1.getValueType();
 
-  SDLoc DL(N);
-
-  if (!ISD::isExtOpcode(N->getOperand(2).getOpcode()))
+  unsigned Op2Opcode = Op2->getOpcode();
+  if (Op2Opcode != ISD::SPLAT_VECTOR || !isOneConstant(Op2->getOperand(0)))
     return SDValue();
-  SDValue Acc = N->getOperand(1);
-  SDValue Ext = N->getOperand(2);
-  EVT AccVT = Acc.getValueType();
-  EVT ExtVT = Ext.getValueType();
-  if (ExtVT.getVectorElementType() != AccVT.getVectorElementType())
+  Op2 = DAG.getAnyExtOrTrunc(Op2, DL, Op1VT);
+
+  if (!(Op1VT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
+      !(Op1VT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
+      !(Op1VT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
     return SDValue();
 
-  SDValue ExtOp = Ext->getOperand(0);
-  EVT ExtOpVT = ExtOp.getValueType();
+  unsigned NewOpcode = ExtOp1Opcode == ISD::SIGN_EXTEND
+                           ? ISD::PARTIAL_REDUCE_SMLA
+                           : ISD::PARTIAL_REDUCE_UMLA;
 
-  if (!(ExtOpVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
-      !(ExtOpVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
-      !(ExtOpVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
-    return SDValue();
+  return DAG.getNode(NewOpcode, DL, AccVT, Op0, Op1, Op2);
+}
 
-  bool ExtOpIsSigned = Ext.getOpcode() == ISD::SIGN_EXTEND;
-  unsigned BottomOpcode =
-      ExtOpIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
-  unsigned TopOpcode = ExtOpIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
-  SDValue BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, ExtOp);
-  return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, ExtOp);
+SDValue performPartialReduceMLACombine(SDNode *N, SelectionDAG &DAG,
+                                       const AArch64Subtarget *Subtarget) {
+  SDLoc DL(N);
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  SDValue Op2 = N->getOperand(2);
+  EVT Op0ElemVT = Op0.getValueType().getVectorElementType();
+  EVT Op1ElemVT = Op1.getValueType().getVectorElementType();
+
+  // If the exts have already been removed or it has already been lowered to an
+  // usdot instruction, then the element types will not be equal
+  if (Op0ElemVT != Op1ElemVT || Op1.getOpcode() == AArch64ISD::USDOT)
+    return SDValue(N, 0);
+
+  if (auto MLA = tryCombinePartialReduceMLAMulOp(Op0, Op1, Op2, DAG, DL)) {
----------------
sdesmalen-arm wrote:

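This seems like a generically useful combine. It only matches target-independent nodes (an ISD::MUL of two extends, with a splat-of-one second multiplicand) and it only produces a generic ISD::PARTIAL_REDUCE_UMLA. So it can be moved to the generic combine code in `DAGCombiner.cpp`.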

https://github.com/llvm/llvm-project/pull/117185


More information about the llvm-commits mailing list