[llvm] [AArch64][SVE] Add partial reduction SDNodes (PR #117185)

Benjamin Maxwell via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 20 08:39:39 PST 2025


================
@@ -22072,77 +22067,102 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
     return SDValue();
 
   // If the extensions are mixed, we should lower it to a usdot instead
-  unsigned Opcode = 0;
+  unsigned DotOpcode = MulOpLHSIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
   if (MulOpLHSIsSigned != MulOpRHSIsSigned) {
     if (!Subtarget->hasMatMulInt8())
       return SDValue();
 
-    bool Scalable = N->getValueType(0).isScalableVT();
+    bool Scalable = ReducedVT.isScalableVT();
     // There's no nxv2i64 version of usdot
     if (Scalable && ReducedVT != MVT::nxv4i32 && ReducedVT != MVT::nxv4i64)
       return SDValue();
 
-    Opcode = AArch64ISD::USDOT;
-    // USDOT expects the signed operand to be last
     if (!MulOpRHSIsSigned)
       std::swap(MulOpLHS, MulOpRHS);
-  } else
-    Opcode = MulOpLHSIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
+    DotOpcode = AArch64ISD::USDOT;
+    // Lower usdot patterns here because legalisation would attempt to split it
+    // unless exts are removed. But, removing the exts would lose the
+    // information about whether each operand is signed.
+    if ((ReducedVT != MVT::nxv4i64 || MulSrcVT != MVT::nxv16i8) &&
+        (ReducedVT != MVT::v4i64 || MulSrcVT != MVT::v16i8))
+      return DAG.getNode(DotOpcode, DL, ReducedVT, Acc, MulOpLHS, MulOpRHS);
+  }
 
   // Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
-  // product followed by a zero / sign extension
+  // product followed by a zero / sign extension. Need to lower this here
+  // because legalisation would attempt to split it.
   if ((ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
       (ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) {
     EVT ReducedVTI32 =
         (ReducedVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
 
     SDValue DotI32 =
-        DAG.getNode(Opcode, DL, ReducedVTI32,
+        DAG.getNode(DotOpcode, DL, ReducedVTI32,
                     DAG.getConstant(0, DL, ReducedVTI32), MulOpLHS, MulOpRHS);
     SDValue Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedVT);
     return DAG.getNode(ISD::ADD, DL, ReducedVT, Acc, Extended);
   }
 
-  return DAG.getNode(Opcode, DL, ReducedVT, Acc, MulOpLHS, MulOpRHS);
+  unsigned NewOpcode =
+      MulOpLHSIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+  return DAG.getNode(NewOpcode, DL, ReducedVT, Acc, MulOpLHS, MulOpRHS);
 }
 
-SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
-                                          const AArch64Subtarget *Subtarget,
-                                          SelectionDAG &DAG) {
-
-  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
-         getIntrinsicID(N) ==
-             Intrinsic::experimental_vector_partial_reduce_add &&
-         "Expected a partial reduction node");
-
+SDValue tryCombineToWideAdd(SDValue &Op0, SDValue &Op1, SDValue &Op2,
+                            SelectionDAG &DAG,
+                            const AArch64Subtarget *Subtarget, SDLoc &DL) {
   if (!Subtarget->hasSVE2() && !Subtarget->isStreamingSVEAvailable())
     return SDValue();
+  unsigned Op1Opcode = Op1->getOpcode();
+  if (!ISD::isExtOpcode(Op1Opcode))
+    return SDValue();
 
-  SDLoc DL(N);
+  EVT AccVT = Op0->getValueType(0);
+  Op1 = Op1->getOperand(0);
+  EVT Op1VT = Op1.getValueType();
+  // Makes Op2's value type match the value type of Op1 without its extend.
+  Op2 = DAG.getAnyExtOrTrunc(Op2, DL, Op1VT);
+  // Make a MUL between Op1 and Op2 here so the MUL can be changed if possible
+  // (can be pruned or changed to a shift instruction for example).
+  SDValue Input = DAG.getNode(ISD::MUL, DL, Op1VT, Op1, Op2);
 
-  if (!ISD::isExtOpcode(N->getOperand(2).getOpcode()))
+  if (!AccVT.isScalableVector())
     return SDValue();
-  SDValue Acc = N->getOperand(1);
-  SDValue Ext = N->getOperand(2);
-  EVT AccVT = Acc.getValueType();
-  EVT ExtVT = Ext.getValueType();
-  if (ExtVT.getVectorElementType() != AccVT.getVectorElementType())
+
+  if (!(Op1VT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
+      !(Op1VT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
+      !(Op1VT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
     return SDValue();
 
-  SDValue ExtOp = Ext->getOperand(0);
-  EVT ExtOpVT = ExtOp.getValueType();
+  unsigned NewOpcode = Op1Opcode == ISD::SIGN_EXTEND ? ISD::PARTIAL_REDUCE_SMLA
+                                                     : ISD::PARTIAL_REDUCE_UMLA;
+  // Return a constant of 1s for Op2 so the MUL is not performed again.
+  return DAG.getNode(NewOpcode, DL, AccVT, Op0, Input,
+                     DAG.getConstant(1, DL, Op1VT));
+}
 
-  if (!(ExtOpVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
-      !(ExtOpVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
-      !(ExtOpVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
-    return SDValue();
+SDValue performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
----------------
MacDue wrote:

nit: name this after the PARTIAL_REDUCE_*MLA nodes created above:
```suggestion
SDValue performPartialReduceMLACombine(SDNode *N, SelectionDAG &DAG,
```

https://github.com/llvm/llvm-project/pull/117185


More information about the llvm-commits mailing list