[llvm] [AArch64][SelectionDAG] Add support for 8to64 partial reduction cases (PR #138269)

Nicholas Guy via llvm-commits llvm-commits at lists.llvm.org
Fri May 2 09:23:14 PDT 2025


https://github.com/NickGuy-Arm updated https://github.com/llvm/llvm-project/pull/138269

From 3187e07e23afcd6e7e5e93cf634d909d595460bd Mon Sep 17 00:00:00 2001
From: Nick Guy <nicholas.guy at arm.com>
Date: Fri, 2 May 2025 13:57:56 +0100
Subject: [PATCH 1/3] [AArch64][SelectionDAG] Add support for 8to64 partial
 reduction cases

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 28 ++++++
 llvm/lib/Target/AArch64/AArch64ISelLowering.h |  1 +
 .../AArch64/sve-partial-reduce-dot-product.ll | 86 +++----------------
 3 files changed, 41 insertions(+), 74 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 0126b97c9fb9a..ca63473ab085b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1867,6 +1867,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     // Other pairs will default to 'Expand'.
     setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Legal);
     setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal);
+
+    setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom);
   }
 
   // Handle operations that are only available in non-streaming SVE mode.
@@ -7767,6 +7769,9 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
     return LowerFLDEXP(Op, DAG);
   case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
     return LowerVECTOR_HISTOGRAM(Op, DAG);
+  case ISD::PARTIAL_REDUCE_SMLA:
+  case ISD::PARTIAL_REDUCE_UMLA:
+    return LowerPARTIAL_REDUCE_MLA(Op, DAG);
   }
 }
 
@@ -29509,6 +29514,29 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
   return Scatter;
 }
 
+SDValue
+AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
+                                               SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+
+  auto Acc = Op.getOperand(0);
+  auto LHS = Op.getOperand(1);
+  auto RHS = Op.getOperand(2);
+
+  auto ResultVT = Op.getValueType();
+
+  assert(ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8);
+
+  auto NewAcc = DAG.getConstant(0, DL, MVT::nxv4i32);
+  auto DotNode =
+      DAG.getNode(Op.getOpcode(), DL, MVT::nxv4i32, NewAcc, LHS, RHS);
+
+  auto Lo = DAG.getNode(AArch64ISD::UUNPKLO, DL, ResultVT, DotNode);
+  auto Hi = DAG.getNode(AArch64ISD::UUNPKHI, DL, ResultVT, DotNode);
+  auto Extended = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, Hi);
+  return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended);
+}
+
 SDValue
 AArch64TargetLowering::LowerFixedLengthFPToIntToSVE(SDValue Op,
                                                     SelectionDAG &DAG) const {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index d9b535b910b80..9d8d1c22258be 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1181,6 +1181,7 @@ class AArch64TargetLowering : public TargetLowering {
   SDValue LowerVECTOR_DEINTERLEAVE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVECTOR_INTERLEAVE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVECTOR_HISTOGRAM(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerPARTIAL_REDUCE_MLA(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerDIV(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerMUL(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVectorSRA_SRL_SHL(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 039cac01008b8..4d5996ee955ce 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -198,43 +198,12 @@ define <vscale x 4 x i64> @udot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8
 ;
 ; CHECK-NEWLOWERING-LABEL: udot_8to64:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    uunpklo z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.h, z3.b
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.h, z3.b
-; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
-; CHECK-NEWLOWERING-NEXT:    uunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z7.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z24.s, z5.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z25.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.s, z5.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z26.d, z6.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z27.d, z7.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z28.d, z24.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z29.d, z25.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z6.d, z6.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z7.d, z7.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z24.d, z24.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z25.d, z25.s
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z28.d
-; CHECK-NEWLOWERING-NEXT:    uunpklo z26.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z28.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT:    uunpklo z27.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z29.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z6.d, z24.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z7.d, z25.d
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z28.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z4.d, z5.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z2.d, z3.d
+; CHECK-NEWLOWERING-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-NEXT:    udot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z3.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    add z2.d, z3.d, z2.d
+; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z2.d
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
@@ -258,43 +227,12 @@ define <vscale x 4 x i64> @sdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8
 ;
 ; CHECK-NEWLOWERING-LABEL: sdot_8to64:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    sunpklo z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.h, z3.b
-; CHECK-NEWLOWERING-NEXT:    sunpklo z3.h, z3.b
-; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
-; CHECK-NEWLOWERING-NEXT:    sunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z7.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z24.s, z5.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z25.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.s, z5.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z26.d, z6.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z27.d, z7.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z28.d, z24.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z29.d, z25.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z6.d, z6.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z7.d, z7.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z24.d, z24.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z25.d, z25.s
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z28.d
-; CHECK-NEWLOWERING-NEXT:    sunpklo z26.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z28.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT:    sunpklo z27.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z29.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z6.d, z24.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z7.d, z25.d
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z26.d, z28.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT:    mla z1.d, p0/m, z4.d, z5.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z2.d, z3.d
+; CHECK-NEWLOWERING-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-NEXT:    sdot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    uunpklo z3.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    add z2.d, z3.d, z2.d
+; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z2.d
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
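
The key to patch 1 is that a single (u|s)dot with a zeroed accumulator cannot
overflow its 32-bit lanes: each lane receives only four byte-by-byte products,
at most 4 * 255 * 255 = 260100 in the unsigned case. Below is a minimal scalar
model of the unsigned path, assuming vscale == 1 (one 128-bit register); the
array-based modelling is illustrative, not LLVM API. Partial reductions only
guarantee the overall sum, so the check compares totals:

#include <cassert>
#include <cstdint>

int main() {
  uint8_t a[16], b[16];
  for (int i = 0; i < 16; ++i) { a[i] = 200 + i; b[i] = 100 + i; }
  uint64_t acc[2] = {1000, 2000};   // incoming nxv2i64 accumulator

  // udot into a zeroed nxv4i32: lane k sums the four products of bytes
  // 4k..4k+3, which always fits in an i32.
  uint32_t dot[4] = {};
  for (int i = 0; i < 16; ++i)
    dot[i / 4] += (uint32_t)a[i] * (uint32_t)b[i];

  // uunpklo/uunpkhi zero-extend the low/high i32 halves to i64, then two
  // adds fold them into the accumulator: lane j gets dot lanes j and j+2.
  uint64_t res[2];
  for (int j = 0; j < 2; ++j)
    res[j] = acc[j] + (uint64_t)dot[j] + (uint64_t)dot[j + 2];

  // Reference: widen every product to i64 up front.
  uint64_t total = acc[0] + acc[1];
  for (int i = 0; i < 16; ++i) total += (uint64_t)a[i] * (uint64_t)b[i];
  assert(res[0] + res[1] == total);
  return 0;
}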

From 00371dfadfc9adcdb05dc3479dcf41385c6da34a Mon Sep 17 00:00:00 2001
From: Nick Guy <nicholas.guy at arm.com>
Date: Fri, 2 May 2025 16:08:31 +0100
Subject: [PATCH 2/3] Switch to using wide adds instead of unpacking

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 17 ++++++++---------
 .../AArch64/sve-partial-reduce-dot-product.ll   | 12 ++++--------
 2 files changed, 12 insertions(+), 17 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index ca63473ab085b..d3bbe7a935dd4 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -29522,19 +29522,18 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
   auto Acc = Op.getOperand(0);
   auto LHS = Op.getOperand(1);
   auto RHS = Op.getOperand(2);
-
   auto ResultVT = Op.getValueType();
-
   assert(ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8);
 
-  auto NewAcc = DAG.getConstant(0, DL, MVT::nxv4i32);
-  auto DotNode =
-      DAG.getNode(Op.getOpcode(), DL, MVT::nxv4i32, NewAcc, LHS, RHS);
+  EVT InputVT = MVT::nxv4i32;
+  SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, InputVT,
+                                DAG.getConstant(0, DL, InputVT), LHS, RHS);
 
-  auto Lo = DAG.getNode(AArch64ISD::UUNPKLO, DL, ResultVT, DotNode);
-  auto Hi = DAG.getNode(AArch64ISD::UUNPKHI, DL, ResultVT, DotNode);
-  auto Extended = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, Hi);
-  return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended);
+  bool IsUnsigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_UMLA;
+  unsigned LoOpcode = IsUnsigned ? AArch64ISD::UADDWB : AArch64ISD::SADDWB;
+  unsigned HiOpcode = IsUnsigned ? AArch64ISD::UADDWT : AArch64ISD::SADDWT;
+  SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, Acc, DotNode);
+  return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
 }
 
 SDValue
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 4d5996ee955ce..811202e5c9662 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -200,10 +200,8 @@ define <vscale x 4 x i64> @udot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
 ; CHECK-NEWLOWERING-NEXT:    movi v4.2d, #0000000000000000
 ; CHECK-NEWLOWERING-NEXT:    udot z4.s, z2.b, z3.b
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    add z2.d, z3.d, z2.d
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-NEXT:    uaddwb z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    uaddwt z0.d, z0.d, z4.s
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
@@ -229,10 +227,8 @@ define <vscale x 4 x i64> @sdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
 ; CHECK-NEWLOWERING-NEXT:    movi v4.2d, #0000000000000000
 ; CHECK-NEWLOWERING-NEXT:    sdot z4.s, z2.b, z3.b
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    add z2.d, z3.d, z2.d
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-NEXT:    saddwb z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-NEXT:    saddwt z0.d, z0.d, z4.s
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
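
Patch 2 swaps the unpack-and-add chain for the SVE2 widening adds, going from
four instructions after the dot to two. UADDWB adds the even-numbered
("bottom") i32 lanes of the dot result, widened, into the i64 accumulator
lanes, and UADDWT does the same for the odd ("top") lanes. A scalar sketch of
that pairing, again assuming vscale == 1 and using illustrative names rather
than LLVM API:

#include <cstdint>

// dot is the nxv4i32 (u|s)dot result; acc is the nxv2i64 accumulator.
void uaddwb_uaddwt(uint64_t acc[2], const uint32_t dot[4]) {
  for (int j = 0; j < 2; ++j)
    acc[j] += (uint64_t)dot[2 * j];       // uaddwb: bottom (even) lanes
  for (int j = 0; j < 2; ++j)
    acc[j] += (uint64_t)dot[2 * j + 1];   // uaddwt: top (odd) lanes
}

Note that i64 lane j now receives dot lanes {2j, 2j+1} rather than the
{j, j+2} pairing of the unpack version; both are acceptable because the
partial-reduction node does not pin down which narrow lanes feed which
accumulator lane, only the overall sum.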

From 1d0994befad01a1e64e215b672274983a81d6fdc Mon Sep 17 00:00:00 2001
From: Nick Guy <nicholas.guy at arm.com>
Date: Fri, 2 May 2025 17:22:07 +0100
Subject: [PATCH 3/3] Address comments and add SVE2/SME guard & fallback

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 35 +++++++++----
 .../AArch64/sve-partial-reduce-dot-product.ll | 51 +++++++++++++------
 2 files changed, 60 insertions(+), 26 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index d3bbe7a935dd4..373ae1f154175 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -29514,26 +29514,39 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
   return Scatter;
 }
 
+/// If a PARTIAL_REDUCE_MLA node comes in with an accumulator type that is too
+/// wide to be used for (u|s)dot, we can still make use of the dot product
+/// instruction by instead treating the accumulator as a vector type with twice
+/// as many elements that are each half as wide, accumulating the low and high
+/// parts of the result together in the actual accumulator afterwards.
 SDValue
 AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
                                                SelectionDAG &DAG) const {
   SDLoc DL(Op);
 
-  auto Acc = Op.getOperand(0);
-  auto LHS = Op.getOperand(1);
-  auto RHS = Op.getOperand(2);
-  auto ResultVT = Op.getValueType();
+  SDValue Acc = Op.getOperand(0);
+  SDValue LHS = Op.getOperand(1);
+  SDValue RHS = Op.getOperand(2);
+  EVT ResultVT = Op.getValueType();
   assert(ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8);
 
-  EVT InputVT = MVT::nxv4i32;
-  SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, InputVT,
-                                DAG.getConstant(0, DL, InputVT), LHS, RHS);
+  SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, MVT::nxv4i32,
+                                DAG.getConstant(0, DL, MVT::nxv4i32), LHS, RHS);
 
   bool IsUnsigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_UMLA;
-  unsigned LoOpcode = IsUnsigned ? AArch64ISD::UADDWB : AArch64ISD::SADDWB;
-  unsigned HiOpcode = IsUnsigned ? AArch64ISD::UADDWT : AArch64ISD::SADDWT;
-  SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, Acc, DotNode);
-  return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
+  if (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable()) {
+    unsigned LoOpcode = IsUnsigned ? AArch64ISD::UADDWB : AArch64ISD::SADDWB;
+    unsigned HiOpcode = IsUnsigned ? AArch64ISD::UADDWT : AArch64ISD::SADDWT;
+    SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, Acc, DotNode);
+    return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
+  }
+
+  unsigned LoOpcode = IsUnsigned ? AArch64ISD::UUNPKLO : AArch64ISD::SUNPKLO;
+  unsigned HiOpcode = IsUnsigned ? AArch64ISD::UUNPKHI : AArch64ISD::SUNPKHI;
+  auto Lo = DAG.getNode(LoOpcode, DL, ResultVT, DotNode);
+  auto Hi = DAG.getNode(HiOpcode, DL, ResultVT, DotNode);
+  auto Extended = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, Hi);
+  return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended);
 }
 
 SDValue
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 811202e5c9662..b4552553220b5 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -1,7 +1,8 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
 ; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-I8MM
 ; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM
-; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING
+; RUN: llc -mtriple=aarch64 -mattr=+sve,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SVE
+; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SVE2
 
 define <vscale x 4 x i32> @udot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
 ; CHECK-LABEL: udot:
@@ -196,13 +197,23 @@ define <vscale x 4 x i64> @udot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8
 ; CHECK-NEXT:    add z1.d, z1.d, z3.d
 ; CHECK-NEXT:    ret
 ;
-; CHECK-NEWLOWERING-LABEL: udot_8to64:
-; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    movi v4.2d, #0000000000000000
-; CHECK-NEWLOWERING-NEXT:    udot z4.s, z2.b, z3.b
-; CHECK-NEWLOWERING-NEXT:    uaddwb z0.d, z0.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    uaddwt z0.d, z0.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    ret
+; CHECK-NEWLOWERING-SVE-LABEL: udot_8to64:
+; CHECK-NEWLOWERING-SVE:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE-NEXT:    udot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpkhi z2.d, z4.s
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpklo z3.d, z4.s
+; CHECK-NEWLOWERING-SVE-NEXT:    add z2.d, z3.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: udot_8to64:
+; CHECK-NEWLOWERING-SVE2:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE2-NEXT:    udot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-SVE2-NEXT:    uaddwb z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    uaddwt z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
   %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i64>
@@ -223,13 +234,23 @@ define <vscale x 4 x i64> @sdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8
 ; CHECK-NEXT:    add z1.d, z1.d, z3.d
 ; CHECK-NEXT:    ret
 ;
-; CHECK-NEWLOWERING-LABEL: sdot_8to64:
-; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    movi v4.2d, #0000000000000000
-; CHECK-NEWLOWERING-NEXT:    sdot z4.s, z2.b, z3.b
-; CHECK-NEWLOWERING-NEXT:    saddwb z0.d, z0.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    saddwt z0.d, z0.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    ret
+; CHECK-NEWLOWERING-SVE-LABEL: sdot_8to64:
+; CHECK-NEWLOWERING-SVE:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE-NEXT:    sdot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpkhi z2.d, z4.s
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpklo z3.d, z4.s
+; CHECK-NEWLOWERING-SVE-NEXT:    add z2.d, z3.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: sdot_8to64:
+; CHECK-NEWLOWERING-SVE2:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE2-NEXT:    sdot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-SVE2-NEXT:    saddwb z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    saddwt z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
   %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i64>
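
Patch 3 gates the widening adds on Subtarget->hasSVE2() ||
Subtarget->isStreamingSVEAvailable(), since UADDWB/UADDWT and SADDWB/SADDWT
are SVE2 instructions (also usable in streaming mode), and restores the
unpack-based sequence as the plain-SVE fallback. One detail worth noting: the
fallback now selects SUNPKLO/SUNPKHI for the signed case, where patch 1 used
the unsigned unpacks unconditionally. That matters because a lane of the sdot
result can be negative, so widening it to i64 must sign-extend. A scalar
model of the signed fallback, under the same vscale == 1 assumption and with
illustrative names:

#include <cstdint>

void sdot_fallback(int64_t acc[2], const int8_t a[16], const int8_t b[16]) {
  // sdot into a zeroed nxv4i32: lane k can be as low as 4 * (-128 * 127).
  int32_t dot[4] = {};
  for (int i = 0; i < 16; ++i)
    dot[i / 4] += (int32_t)a[i] * (int32_t)b[i];

  // sunpklo/sunpkhi sign-extend the low/high halves before the adds;
  // zero-extending here would corrupt negative partial sums.
  for (int j = 0; j < 2; ++j)
    acc[j] += (int64_t)dot[j] + (int64_t)dot[j + 2];
}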


