[llvm] [SelectionDAG][AArch64] Add dot product lowering in NEON for PARTIAL_REDUCE_*MLA ISD nodes (PR #140075)

Benjamin Maxwell via llvm-commits llvm-commits at lists.llvm.org
Thu May 15 09:00:26 PDT 2025


================
@@ -29518,37 +29538,64 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
 }
 
 /// If a PARTIAL_REDUCE_MLA node comes in with an accumulator-input type pairing
-/// of nxv2i64/nxv16i8, we cannot directly lower it to a (u|s)dot. We can
+/// of v2i64/v16i8, we cannot directly lower it to a (u|s)dot. We can
 /// however still make use of the dot product instruction by instead
-/// accumulating over two steps: nxv16i8 -> nxv4i32 -> nxv2i64.
+/// accumulating over two steps: v16i8 -> v4i32 -> v2i64.
 SDValue
 AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
                                                SelectionDAG &DAG) const {
+  bool Scalable = Op.getValueType().isScalableVector();
+  if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
+    return SDValue();
+  if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
+    return SDValue();
+
   SDLoc DL(Op);
 
   SDValue Acc = Op.getOperand(0);
   SDValue LHS = Op.getOperand(1);
   SDValue RHS = Op.getOperand(2);
   EVT ResultVT = Op.getValueType();
-  assert(ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8);
 
-  SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, MVT::nxv4i32,
-                                DAG.getConstant(0, DL, MVT::nxv4i32), LHS, RHS);
+  assert((Scalable && ResultVT == MVT::nxv2i64 &&
+          LHS.getValueType() == MVT::nxv16i8) ||
+         (!Scalable && ResultVT == MVT::v2i64 &&
+          LHS.getValueType() == MVT::v16i8));
+
+  EVT DotVT = Scalable ? MVT::nxv4i32 : MVT::v4i32;
+  SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, DotVT,
+                                DAG.getConstant(0, DL, DotVT), LHS, RHS);
 
   bool IsUnsigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_UMLA;
-  if (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable()) {
+  if (Scalable &&
+      (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable())) {
     unsigned LoOpcode = IsUnsigned ? AArch64ISD::UADDWB : AArch64ISD::SADDWB;
     unsigned HiOpcode = IsUnsigned ? AArch64ISD::UADDWT : AArch64ISD::SADDWT;
     SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, Acc, DotNode);
     return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
   }
 
-  unsigned LoOpcode = IsUnsigned ? AArch64ISD::UUNPKLO : AArch64ISD::SUNPKLO;
-  unsigned HiOpcode = IsUnsigned ? AArch64ISD::UUNPKHI : AArch64ISD::SUNPKHI;
-  auto Lo = DAG.getNode(LoOpcode, DL, ResultVT, DotNode);
-  auto Hi = DAG.getNode(HiOpcode, DL, ResultVT, DotNode);
-  auto Extended = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, Hi);
-  return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended);
+  if (Scalable) {
+    unsigned LoOpcode = IsUnsigned ? AArch64ISD::UUNPKLO : AArch64ISD::SUNPKLO;
+    unsigned HiOpcode = IsUnsigned ? AArch64ISD::UUNPKHI : AArch64ISD::SUNPKHI;
+    auto Lo = DAG.getNode(LoOpcode, DL, ResultVT, DotNode);
+    auto Hi = DAG.getNode(HiOpcode, DL, ResultVT, DotNode);
+    auto Extended = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, Hi);
+    return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended);
+  }
+
+  // Fold v4i32 into v2i64
+  // SDValues
+  auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
+  if (IsUnsigned) {
+    DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, MVT::v2i64);
+    DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, MVT::v2i64);
+  } else {
+    DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, MVT::v2i64);
+    DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, MVT::v2i64);
+  }
+  auto Lo = DAG.getNode(ISD::ADD, DL, MVT::v2i64, Acc, DotNodeLo);
+  return DAG.getNode(ISD::ADD, DL, MVT::v2i64, Lo, DotNodeHi);
----------------
MacDue wrote:

It looks like this would work for both Neon and SVE. Is there any reason this isn't simply:
```suggestion
  // Fold (nx)v4i32 into (nx)v2i64
  auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
  if (IsUnsigned) {
    DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
    DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
  } else {
    DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
    DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
  }
  auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
  return DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
```

https://github.com/llvm/llvm-project/pull/140075


More information about the llvm-commits mailing list