[llvm] [AArch64][SVE] Add lowering for PARTIAL_REDUCE_U/SMLA to USDOT (PR #131327)

Nicholas Guy via llvm-commits llvm-commits at lists.llvm.org
Mon May 12 07:57:10 PDT 2025


https://github.com/NickGuy-Arm updated https://github.com/llvm/llvm-project/pull/131327

>From a20bced7525462a1ceaa91e60f8a934995facb29 Mon Sep 17 00:00:00 2001
From: Nick Guy <nicholas.guy at arm.com>
Date: Tue, 6 May 2025 16:59:30 +0100
Subject: [PATCH 1/5] [AArch64][SVE] Add lowering for PARTIAL_REDUCE_U/SMLA to
 USDOT

Add lowering for PARTIAL_REDUCE_U/SMLA nodes to USDOT instructions.
This lowering applies when the second operand of the ISD node is a
MUL instruction whose two operands are extends of differing
signedness (one sign-extend and one zero-extend).
---
 .../CodeGen/SelectionDAG/LegalizeTypes.cpp    |  15 ++-
 .../Target/AArch64/AArch64ISelLowering.cpp    |  80 +++++++++++-
 llvm/lib/Target/AArch64/AArch64ISelLowering.h |   1 +
 .../AArch64/sve-partial-reduce-dot-product.ll | 122 ++----------------
 4 files changed, 107 insertions(+), 111 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
index 83fade45d1892..1af60d6896e6d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
@@ -924,8 +924,19 @@ SDValue DAGTypeLegalizer::CreateStackStoreLoad(SDValue Op,
 /// illegal ResNo in that case.
 bool DAGTypeLegalizer::CustomLowerNode(SDNode *N, EVT VT, bool LegalizeResult) {
   // See if the target wants to custom lower this node.
-  if (TLI.getOperationAction(N->getOpcode(), VT) != TargetLowering::Custom)
-    return false;
+  unsigned Opcode = N->getOpcode();
+  bool IsPRMLAOpcode =
+      Opcode == ISD::PARTIAL_REDUCE_UMLA || Opcode == ISD::PARTIAL_REDUCE_SMLA;
+
+  if (IsPRMLAOpcode) {
+    if (TLI.getPartialReduceMLAAction(N->getValueType(0),
+                                      N->getOperand(1).getValueType()) !=
+        TargetLowering::Custom)
+      return false;
+  } else {
+    if (TLI.getOperationAction(Opcode, VT) != TargetLowering::Custom)
+      return false;
+  }
 
   SmallVector<SDValue, 8> Results;
   if (LegalizeResult)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 6e7f13c20db68..cbf4c89781b96 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -7742,8 +7742,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
     return LowerFLDEXP(Op, DAG);
   case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
     return LowerVECTOR_HISTOGRAM(Op, DAG);
-  case ISD::PARTIAL_REDUCE_SMLA:
   case ISD::PARTIAL_REDUCE_UMLA:
+  case ISD::PARTIAL_REDUCE_SMLA:
     return LowerPARTIAL_REDUCE_MLA(Op, DAG);
   }
 }
@@ -27533,6 +27533,10 @@ void AArch64TargetLowering::ReplaceNodeResults(
     if (SDValue Res = LowerVECTOR_COMPRESS(SDValue(N, 0), DAG))
       Results.push_back(Res);
     return;
+  case ISD::PARTIAL_REDUCE_UMLA:
+  case ISD::PARTIAL_REDUCE_SMLA:
+    Results.push_back(LowerPARTIAL_REDUCE_MLA(SDValue(N, 0), DAG));
+    return;
   case ISD::ADD:
   case ISD::FADD:
     ReplaceAddWithADDP(N, Results, DAG, Subtarget);
@@ -29515,6 +29519,80 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
   return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended);
 }
 
+// Lower PARTIAL_REDUCE_*MLA(Acc, MUL(ZEXT(MulOpLHS), SEXT(MulOpRHS)), Splat 1)
+// to USDOT(Acc, MulOpLHS, MulOpRHS)
+// Lower PARTIAL_REDUCE_*MLA(Acc, MUL(SEXT(MulOpLHS), ZEXT(MulOpRHS)), Splat 1)
+// to USDOT(Acc, MulOpRHS, MulOpLHS)
+SDValue
+AArch64TargetLowering::LowerPARTIAL_REDUCE_MLAToUSDOT(SDValue Op,
+                                               SelectionDAG &DAG) const {
+  bool Scalable = Op.getValueType().isScalableVector();
+  auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
+  if (Scalable && !Subtarget.isSVEorStreamingSVEAvailable())
+    return SDValue();
+  if (!Scalable && (!Subtarget.isNeonAvailable() || !Subtarget.hasDotProd()))
+    return SDValue();
+  if (!Subtarget.hasMatMulInt8())
+    return SDValue();
+  SDLoc DL(Op);
+
+  if (Op.getOperand(1).getOpcode() != ISD::MUL)
+    return SDValue();
+
+  SDValue Acc = Op.getOperand(0);
+  SDValue Mul = Op.getOperand(1);
+
+  APInt ConstantOne;
+  if (!ISD::isConstantSplatVector(Op.getOperand(2).getNode(), ConstantOne) ||
+      !ConstantOne.isOne())
+    return SDValue();
+
+  SDValue ExtMulOpLHS = Mul.getOperand(0);
+  SDValue ExtMulOpRHS = Mul.getOperand(1);
+  unsigned ExtMulOpLHSOpcode = ExtMulOpLHS.getOpcode();
+  unsigned ExtMulOpRHSOpcode = ExtMulOpRHS.getOpcode();
+  if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
+      !ISD::isExtOpcode(ExtMulOpRHSOpcode))
+    return SDValue();
+
+  SDValue MulOpLHS = ExtMulOpLHS.getOperand(0);
+  SDValue MulOpRHS = ExtMulOpRHS.getOperand(0);
+  EVT MulOpLHSVT = MulOpLHS.getValueType();
+  if (MulOpLHSVT != MulOpRHS.getValueType())
+    return SDValue();
+
+  bool LHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
+  bool RHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
+  if (LHSIsSigned == RHSIsSigned)
+    return SDValue();
+
+  EVT AccVT = Acc.getValueType();
+  // There is no nxv2i64 version of usdot
+  if (Scalable && AccVT != MVT::nxv4i32 && AccVT != MVT::nxv4i64)
+    return SDValue();
+
+  // USDOT expects the signed operand to be last
+  if (!RHSIsSigned)
+    std::swap(MulOpLHS, MulOpRHS);
+
+  unsigned Opcode = AArch64ISD::USDOT;
+  // Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
+  // product followed by a zero / sign extension
+  // Don't want this to be split because there is no nxv2i64 version of usdot
+  if ((AccVT == MVT::nxv4i64 && MulOpLHSVT == MVT::nxv16i8) ||
+      (AccVT == MVT::v4i64 && MulOpLHSVT == MVT::v16i8)) {
+    EVT AccVTI32 = (AccVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
+
+    SDValue DotI32 =
+        DAG.getNode(Opcode, DL, AccVTI32, DAG.getConstant(0, DL, AccVTI32),
+                    MulOpLHS, MulOpRHS);
+    SDValue Extended = DAG.getSExtOrTrunc(DotI32, DL, AccVT);
+    return DAG.getNode(ISD::ADD, DL, AccVT, Acc, Extended);
+  }
+
+  return DAG.getNode(Opcode, DL, AccVT, Acc, MulOpLHS, MulOpRHS);
+}
+
 SDValue
 AArch64TargetLowering::LowerFixedLengthFPToIntToSVE(SDValue Op,
                                                     SelectionDAG &DAG) const {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 9d8d1c22258be..6f0fb03bae0ea 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1182,6 +1182,7 @@ class AArch64TargetLowering : public TargetLowering {
   SDValue LowerVECTOR_INTERLEAVE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVECTOR_HISTOGRAM(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerPARTIAL_REDUCE_MLA(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerPARTIAL_REDUCE_MLAToUSDOT(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerDIV(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerMUL(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVectorSRA_SRL_SHL(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 5bc9a101b1e44..b7127370d2415 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -106,23 +106,7 @@ define <vscale x 4 x i32> @usdot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a,
 ;
 ; CHECK-NEWLOWERING-LABEL: usdot:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    sunpklo z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    ptrue p0.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    uunpklo z5.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z5.s, z6.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z5.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z6.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z3.s, z4.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z5.s, z6.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-NEXT:    usdot z0.s, z1.b, z2.b
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
@@ -161,23 +145,7 @@ define <vscale x 4 x i32> @sudot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a,
 ;
 ; CHECK-NEWLOWERING-LABEL: sudot:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    sunpklo z3.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    uunpklo z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    ptrue p0.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    sunpklo z5.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z5.s, z6.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z5.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z6.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z3.s, z4.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z5.s, z6.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-NEXT:    usdot z0.s, z2.b, z1.b
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
@@ -331,43 +299,12 @@ define <vscale x 4 x i64> @usdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i
 ;
 ; CHECK-NEWLOWERING-LABEL: usdot_8to64:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    uunpklo z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.h, z3.b
-; CHECK-NEWLOWERING-NEXT:    sunpklo z3.h, z3.b
-; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
-; CHECK-NEWLOWERING-NEXT:    uunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z7.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z24.s, z5.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z25.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.s, z5.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z26.d, z6.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z27.d, z7.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z28.d, z24.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z29.d, z25.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z6.d, z6.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z7.d, z7.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z24.d, z24.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z25.d, z25.s
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z28.d
-; CHECK-NEWLOWERING-NEXT:    uunpklo z26.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z28.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT:    uunpklo z27.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z29.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z6.d, z24.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z7.d, z25.d
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z28.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z4.d, z5.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z2.d, z3.d
+; CHECK-NEWLOWERING-NEXT:    mov z4.s, #0 // =0x0
+; CHECK-NEWLOWERING-NEXT:    usdot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-NEXT:    sunpklo z2.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-NEXT:    add z1.d, z1.d, z3.d
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
@@ -432,43 +369,12 @@ define <vscale x 4 x i64> @sudot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i
 ;
 ; CHECK-NEWLOWERING-LABEL: sudot_8to64:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    sunpklo z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.h, z3.b
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.h, z3.b
-; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
-; CHECK-NEWLOWERING-NEXT:    sunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z7.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z24.s, z5.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z25.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.s, z5.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z26.d, z6.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z27.d, z7.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z28.d, z24.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z29.d, z25.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z6.d, z6.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z7.d, z7.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z24.d, z24.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z25.d, z25.s
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z28.d
-; CHECK-NEWLOWERING-NEXT:    sunpklo z26.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z28.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT:    sunpklo z27.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z29.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z6.d, z24.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z7.d, z25.d
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z28.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z4.d, z5.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z2.d, z3.d
+; CHECK-NEWLOWERING-NEXT:    mov z4.s, #0 // =0x0
+; CHECK-NEWLOWERING-NEXT:    usdot z4.s, z3.b, z2.b
+; CHECK-NEWLOWERING-NEXT:    sunpklo z2.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-NEXT:    add z1.d, z1.d, z3.d
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>

>From 7a62406e97fda3479f0d3168e74a4320ece94801 Mon Sep 17 00:00:00 2001
From: Nick Guy <nicholas.guy at arm.com>
Date: Tue, 6 May 2025 17:00:42 +0100
Subject: [PATCH 2/5] Add calls to setPartialReduceMLAAction.

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index cbf4c89781b96..d92c57885ca41 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1869,7 +1869,17 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Legal);
     setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal);
 
+    // 8to64
     setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom);
+
+    // USDOT
+    setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i64, Custom);
+    setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i32, Custom);
+
+    setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv4i64, Custom);
+    setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv8i32, Custom);
+    setPartialReduceMLAAction(MVT::nxv8i16, MVT::nxv16i16, Custom);
+    setPartialReduceMLAAction(MVT::nxv16i8, MVT::nxv32i8, Custom);
   }
 
   // Handle operations that are only available in non-streaming SVE mode.

>From 7868964015736bb88d1db3243f5eab440b16df69 Mon Sep 17 00:00:00 2001
From: Nick Guy <nicholas.guy at arm.com>
Date: Wed, 7 May 2025 17:45:38 +0100
Subject: [PATCH 3/5] Rebase and update tests

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 48 +++++-----
 .../AArch64/sve-partial-reduce-dot-product.ll | 88 ++++++++++++++++---
 2 files changed, 101 insertions(+), 35 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index d92c57885ca41..8ee62844d4c99 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1872,14 +1872,15 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     // 8to64
     setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom);
 
-    // USDOT
-    setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i64, Custom);
-    setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i32, Custom);
-
-    setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv4i64, Custom);
-    setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv8i32, Custom);
-    setPartialReduceMLAAction(MVT::nxv8i16, MVT::nxv16i16, Custom);
-    setPartialReduceMLAAction(MVT::nxv16i8, MVT::nxv32i8, Custom);
+    if (Subtarget->hasMatMulInt8()) {
+      // USDOT
+      setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i64, Custom);
+      setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i32, Custom);
+      setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv4i64, Custom);
+      setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv8i32, Custom);
+      setPartialReduceMLAAction(MVT::nxv8i16, MVT::nxv16i16, Custom);
+      setPartialReduceMLAAction(MVT::nxv16i8, MVT::nxv32i8, Custom);
+    }
   }
 
   // Handle operations that are only available in non-streaming SVE mode.
@@ -29495,21 +29496,24 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
   return Scatter;
 }
 
-/// If a PARTIAL_REDUCE_MLA node comes in with an accumulator-input type pairing
-/// of nxv2i64/nxv16i8, we cannot directly lower it to a (u|s)dot. We can
-/// however still make use of the dot product instruction by instead
-/// accumulating over two steps: nxv16i8 -> nxv4i32 -> nxv2i64.
 SDValue
 AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
                                                SelectionDAG &DAG) const {
-  SDLoc DL(Op);
+  if (SDValue UsdotNode = LowerPARTIAL_REDUCE_MLAToUSDOT(Op, DAG))
+    return UsdotNode;
 
-  SDValue Acc = Op.getOperand(0);
   SDValue LHS = Op.getOperand(1);
-  SDValue RHS = Op.getOperand(2);
   EVT ResultVT = Op.getValueType();
-  assert(ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8);
+  /// If a PARTIAL_REDUCE_MLA node comes in with an accumulator-input type
+  /// pairing of nxv2i64/nxv16i8, we cannot directly lower it to a (u|s)dot. We
+  /// can however still make use of the dot product instruction by instead
+  /// accumulating over two steps: nxv16i8 -> nxv4i32 -> nxv2i64.
+  if (ResultVT != MVT::nxv2i64 || LHS.getValueType() != MVT::nxv16i8)
+    return SDValue();
 
+  SDLoc DL(Op);
+  SDValue Acc = Op.getOperand(0);
+  SDValue RHS = Op.getOperand(2);
   SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, MVT::nxv4i32,
                                 DAG.getConstant(0, DL, MVT::nxv4i32), LHS, RHS);
 
@@ -29529,13 +29533,13 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
   return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended);
 }
 
-// Lower PARTIAL_REDUCE_*MLA(Acc, MUL(ZEXT(MulOpLHS), SEXT(MulOpRHS)), Splat 1)
-// to USDOT(Acc, MulOpLHS, MulOpRHS)
-// Lower PARTIAL_REDUCE_*MLA(Acc, MUL(SEXT(MulOpLHS), ZEXT(MulOpRHS)), Splat 1)
-// to USDOT(Acc, MulOpRHS, MulOpLHS)
+// partial.reduce.umla(acc, mul(zext(mulOpLHS), sext(mulOpRHS)), splat(1))
+// -> USDOT(acc, mulOpLHS, mulOpRHS)
+// partial.reduce.smla(acc, mul(sext(mulOpLHS), zext(mulOpRHS)), splat(1))
+// -> USDOT(acc, mulOpRHS, mulOpLHS)
 SDValue
 AArch64TargetLowering::LowerPARTIAL_REDUCE_MLAToUSDOT(SDValue Op,
-                                               SelectionDAG &DAG) const {
+                                                      SelectionDAG &DAG) const {
   bool Scalable = Op.getValueType().isScalableVector();
   auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
   if (Scalable && !Subtarget.isSVEorStreamingSVEAvailable())
@@ -29591,7 +29595,7 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLAToUSDOT(SDValue Op,
   // Don't want this to be split because there is no nxv2i64 version of usdot
   if ((AccVT == MVT::nxv4i64 && MulOpLHSVT == MVT::nxv16i8) ||
       (AccVT == MVT::v4i64 && MulOpLHSVT == MVT::v16i8)) {
-    EVT AccVTI32 = (AccVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
+    EVT AccVTI32 = AccVT.isScalableVector() ? MVT::nxv4i32 : MVT::v4i32;
 
     SDValue DotI32 =
         DAG.getNode(Opcode, DL, AccVTI32, DAG.getConstant(0, DL, AccVTI32),
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index b7127370d2415..0c5ec2908a16d 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -3,7 +3,7 @@
 ; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM
 ; RUN: llc -mtriple=aarch64 -mattr=+sve,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SVE
 ; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SVE2
-; RUN: llc -mtriple=aarch64 -mattr=+sme -force-streaming -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SME
+; RUN: llc -mtriple=aarch64 -mattr=+sve,+sme,+i8mm -force-streaming -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SME
 
 define <vscale x 4 x i32> @udot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
 ; CHECK-LABEL: udot:
@@ -299,12 +299,43 @@ define <vscale x 4 x i64> @usdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i
 ;
 ; CHECK-NEWLOWERING-LABEL: usdot_8to64:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    mov z4.s, #0 // =0x0
-; CHECK-NEWLOWERING-NEXT:    usdot z4.s, z2.b, z3.b
-; CHECK-NEWLOWERING-NEXT:    sunpklo z2.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z2.d
-; CHECK-NEWLOWERING-NEXT:    add z1.d, z1.d, z3.d
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    uunpklo z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.h, z3.b
+; CHECK-NEWLOWERING-NEXT:    sunpklo z3.h, z3.b
+; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
+; CHECK-NEWLOWERING-NEXT:    uunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z7.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z24.s, z5.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z25.s, z3.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.s, z4.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.s, z5.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z26.d, z6.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z27.d, z7.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z28.d, z24.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z29.d, z25.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z6.d, z6.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z7.d, z7.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z24.d, z24.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z25.d, z25.s
+; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z28.d
+; CHECK-NEWLOWERING-NEXT:    uunpklo z26.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z28.d, z5.s
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z27.d, z29.d
+; CHECK-NEWLOWERING-NEXT:    uunpklo z27.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z29.d, z3.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.d, z5.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z6.d, z24.d
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z7.d, z25.d
+; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z28.d
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z27.d, z29.d
+; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z4.d, z5.d
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z2.d, z3.d
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
@@ -369,12 +400,43 @@ define <vscale x 4 x i64> @sudot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i
 ;
 ; CHECK-NEWLOWERING-LABEL: sudot_8to64:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    mov z4.s, #0 // =0x0
-; CHECK-NEWLOWERING-NEXT:    usdot z4.s, z3.b, z2.b
-; CHECK-NEWLOWERING-NEXT:    sunpklo z2.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z2.d
-; CHECK-NEWLOWERING-NEXT:    add z1.d, z1.d, z3.d
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    sunpklo z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.h, z3.b
+; CHECK-NEWLOWERING-NEXT:    uunpklo z3.h, z3.b
+; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
+; CHECK-NEWLOWERING-NEXT:    sunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z7.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z24.s, z5.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z25.s, z3.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.s, z4.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.s, z5.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z26.d, z6.s
+; CHECK-NEWLOWERING-NEXT:    sunpklo z27.d, z7.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z28.d, z24.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z29.d, z25.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z6.d, z6.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z7.d, z7.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z24.d, z24.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z25.d, z25.s
+; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z28.d
+; CHECK-NEWLOWERING-NEXT:    sunpklo z26.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z28.d, z5.s
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z27.d, z29.d
+; CHECK-NEWLOWERING-NEXT:    sunpklo z27.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z29.d, z3.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.d, z5.s
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z6.d, z24.d
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z7.d, z25.d
+; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z28.d
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z27.d, z29.d
+; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z4.d, z5.d
+; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z2.d, z3.d
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>

>From d40773d22124ae66f4d90db9510169eb8931ca2f Mon Sep 17 00:00:00 2001
From: Nick Guy <nicholas.guy at arm.com>
Date: Thu, 8 May 2025 15:06:26 +0100
Subject: [PATCH 4/5] Adjust how usdot cases are lowered

---
 llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp | 15 ++-------------
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp |  2 ++
 2 files changed, 4 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
index 1af60d6896e6d..e0e0bf8777d87 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
@@ -925,18 +925,8 @@ SDValue DAGTypeLegalizer::CreateStackStoreLoad(SDValue Op,
 bool DAGTypeLegalizer::CustomLowerNode(SDNode *N, EVT VT, bool LegalizeResult) {
   // See if the target wants to custom lower this node.
   unsigned Opcode = N->getOpcode();
-  bool IsPRMLAOpcode =
-      Opcode == ISD::PARTIAL_REDUCE_UMLA || Opcode == ISD::PARTIAL_REDUCE_SMLA;
-
-  if (IsPRMLAOpcode) {
-    if (TLI.getPartialReduceMLAAction(N->getValueType(0),
-                                      N->getOperand(1).getValueType()) !=
-        TargetLowering::Custom)
-      return false;
-  } else {
-    if (TLI.getOperationAction(Opcode, VT) != TargetLowering::Custom)
-      return false;
-  }
+  if (TLI.getOperationAction(Opcode, VT) != TargetLowering::Custom)
+    return false;
 
   SmallVector<SDValue, 8> Results;
   if (LegalizeResult)
@@ -957,7 +947,6 @@ bool DAGTypeLegalizer::CustomLowerNode(SDNode *N, EVT VT, bool LegalizeResult) {
   return true;
 }
 
-
 /// Widen the node's results with custom code provided by the target and return
 /// "true", or do nothing and return "false".
 bool DAGTypeLegalizer::CustomWidenLowerNode(SDNode *N, EVT VT) {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 8ee62844d4c99..c9f48c13305c6 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1880,6 +1880,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv8i32, Custom);
       setPartialReduceMLAAction(MVT::nxv8i16, MVT::nxv16i16, Custom);
       setPartialReduceMLAAction(MVT::nxv16i8, MVT::nxv32i8, Custom);
+
+      setOperationAction(ISD::PARTIAL_REDUCE_UMLA, MVT::nxv16i32, Custom);
     }
   }
 

>From 22636ac1a9bd8e73943f9b52dc5ff063dc5ed258 Mon Sep 17 00:00:00 2001
From: Nick Guy <nicholas.guy at arm.com>
Date: Mon, 12 May 2025 15:56:06 +0100
Subject: [PATCH 5/5] Adjust how usdot cases are lowered

---
 llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp | 14 ++++++++++++--
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 14 +++-----------
 2 files changed, 15 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
index e0e0bf8777d87..4bb61261526f0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
@@ -925,8 +925,18 @@ SDValue DAGTypeLegalizer::CreateStackStoreLoad(SDValue Op,
 bool DAGTypeLegalizer::CustomLowerNode(SDNode *N, EVT VT, bool LegalizeResult) {
   // See if the target wants to custom lower this node.
   unsigned Opcode = N->getOpcode();
-  if (TLI.getOperationAction(Opcode, VT) != TargetLowering::Custom)
-    return false;
+  bool IsPRMLAOpcode =
+      Opcode == ISD::PARTIAL_REDUCE_UMLA || Opcode == ISD::PARTIAL_REDUCE_SMLA;
+
+  if (IsPRMLAOpcode) {
+    if (TLI.getPartialReduceMLAAction(N->getValueType(0),
+                                      N->getOperand(1).getValueType()) !=
+        TargetLowering::Custom)
+      return false;
+  } else {
+    if (TLI.getOperationAction(Opcode, VT) != TargetLowering::Custom)
+      return false;
+  }
 
   SmallVector<SDValue, 8> Results;
   if (LegalizeResult)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index c9f48c13305c6..e98307fba88dd 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1872,17 +1872,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     // 8to64
     setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom);
 
-    if (Subtarget->hasMatMulInt8()) {
-      // USDOT
-      setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i64, Custom);
+    // USDOT
+    if (Subtarget->hasMatMulInt8())
       setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i32, Custom);
-      setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv4i64, Custom);
-      setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv8i32, Custom);
-      setPartialReduceMLAAction(MVT::nxv8i16, MVT::nxv16i16, Custom);
-      setPartialReduceMLAAction(MVT::nxv16i8, MVT::nxv32i8, Custom);
-
-      setOperationAction(ISD::PARTIAL_REDUCE_UMLA, MVT::nxv16i32, Custom);
-    }
   }
 
   // Handle operations that are only available in non-streaming SVE mode.
@@ -7755,8 +7747,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
     return LowerFLDEXP(Op, DAG);
   case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
     return LowerVECTOR_HISTOGRAM(Op, DAG);
-  case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
+  case ISD::PARTIAL_REDUCE_UMLA:
     return LowerPARTIAL_REDUCE_MLA(Op, DAG);
   }
 }



More information about the llvm-commits mailing list