[llvm] [AArch64][NEON][SVE] Lower mixed sign/zero extended partial reductions to usdot (PR #107566)

Sam Tebbs via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 17 02:32:45 PDT 2024


https://github.com/SamTebbs33 updated https://github.com/llvm/llvm-project/pull/107566

>From ac88857b22b24a59b19d88dc38ddf5c11be5454d Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Tue, 3 Sep 2024 11:06:25 +0100
Subject: [PATCH 1/8] [AArch64][NEON][SVE] Lower mixed sign/zero extended
 partial reductions to usdot

This PR adds lowering of partial reductions whose inputs are a mix of
sign- and zero-extended values to the usdot intrinsic.
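
As an illustration, this is the kind of pattern the combine now handles,
taken from the NEON tests added below. Neither udot nor sdot applies,
since one input is zero-extended and the other sign-extended; with +i8mm
the whole sequence selects to a single usdot:

  define <4 x i32> @usdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
    ; %u is zero-extended, %s is sign-extended.
    %u.wide = zext <16 x i8> %u to <16 x i32>
    %s.wide = sext <16 x i8> %s to <16 x i32>
    %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
    ; Accumulating partial reduction of the 16 products into 4 lanes.
    %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mult)
    ret <4 x i32> %partial.reduce
  }

  ; Expected codegen with +dotprod,+i8mm:
  ;   usdot v0.4s, v1.16b, v2.16b
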
---
 .../Target/AArch64/AArch64ISelLowering.cpp    |  62 ++++--
 .../neon-partial-reduce-dot-product.ll        | 195 +++++++++++++++++-
 .../AArch64/sve-partial-reduce-dot-product.ll | 161 +++++++++++++--
 3 files changed, 382 insertions(+), 36 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 9d80087336d230..a3b372e677f98a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21824,37 +21824,59 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
 
   auto ExtA = MulOp->getOperand(0);
   auto ExtB = MulOp->getOperand(1);
-  bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
-  bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
-  if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt))
-    return SDValue();
-
   auto A = ExtA->getOperand(0);
   auto B = ExtB->getOperand(0);
   if (A.getValueType() != B.getValueType())
     return SDValue();
 
-  unsigned Opcode = 0;
-
-  if (IsSExt)
-    Opcode = AArch64ISD::SDOT;
-  else if (IsZExt)
-    Opcode = AArch64ISD::UDOT;
-
-  assert(Opcode != 0 && "Unexpected dot product case encountered.");
-
   EVT ReducedType = N->getValueType(0);
   EVT MulSrcType = A.getValueType();
 
   // Dot products operate on chunks of four elements so there must be four times
   // as many elements in the wide type
-  if ((ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) ||
-      (ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) ||
-      (ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) ||
-      (ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
-    return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
+  if (!(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
+      !(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) &&
+      !(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) &&
+      !(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
+    return SDValue();
 
-  return SDValue();
+  bool AIsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
+  bool AIsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
+  bool BIsSExt = ExtB->getOpcode() == ISD::SIGN_EXTEND;
+  bool BIsZExt = ExtB->getOpcode() == ISD::ZERO_EXTEND;
+  if (!(AIsSExt || AIsZExt) || !(BIsSExt || BIsZExt))
+    return SDValue();
+
+  // If the extensions are mixed, we should lower it to a usdot instead
+  if (AIsZExt != BIsZExt) {
+    if (!Subtarget->hasMatMulInt8())
+      return SDValue();
+    bool Scalable = N->getValueType(0).isScalableVT();
+
+    // There's no nxv2i64 version of usdot
+    if (Scalable && ReducedType != MVT::nxv4i32)
+      return SDValue();
+
+    unsigned IntrinsicID =
+        Scalable ? Intrinsic::aarch64_sve_usdot : Intrinsic::aarch64_neon_usdot;
+    // USDOT expects the first operand to be unsigned, so swap the operands if
+    // the first is signed and the second is unsigned
+    if (AIsSExt && BIsZExt)
+      std::swap(A, B);
+    return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, ReducedType,
+                       DAG.getConstant(IntrinsicID, DL, MVT::i64), NarrowOp, A,
+                       B);
+  }
+
+  unsigned Opcode = 0;
+  if (AIsSExt)
+    Opcode = AArch64ISD::SDOT;
+  else if (AIsZExt)
+    Opcode = AArch64ISD::UDOT;
+
+  assert(Opcode != 0 && "Unexpected dot product case encountered.");
+
+  return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
 }
 
 static SDValue performIntrinsicCombine(SDNode *N,
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index 8035504d5558b1..7b6c01f4691175 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -1,6 +1,7 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT
-; RUN: llc -mtriple aarch64 -mattr=+neon < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NODOT
+; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT
+; RUN: llc -mtriple aarch64 -mattr=+neon,+i8mm < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NODOT
+; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOIMM8
 
 define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
 ; CHECK-DOT-LABEL: udot:
@@ -18,6 +19,11 @@ define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
 ; CHECK-NODOT-NEXT:    uaddw2 v0.4s, v0.4s, v1.8h
 ; CHECK-NODOT-NEXT:    add v0.4s, v2.4s, v0.4s
 ; CHECK-NODOT-NEXT:    ret
+;
+; CHECK-NOIMM8-LABEL: udot:
+; CHECK-NOIMM8:       // %bb.0:
+; CHECK-NOIMM8-NEXT:    udot v0.4s, v2.16b, v1.16b
+; CHECK-NOIMM8-NEXT:    ret
   %u.wide = zext <16 x i8> %u to <16 x i32>
   %s.wide = zext <16 x i8> %s to <16 x i32>
   %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
@@ -45,6 +51,11 @@ define <2 x i32> @udot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
 ; CHECK-NODOT-NEXT:    uaddw v1.4s, v2.4s, v4.4h
 ; CHECK-NODOT-NEXT:    add v0.2s, v1.2s, v0.2s
 ; CHECK-NODOT-NEXT:    ret
+;
+; CHECK-NOIMM8-LABEL: udot_narrow:
+; CHECK-NOIMM8:       // %bb.0:
+; CHECK-NOIMM8-NEXT:    udot v0.2s, v2.8b, v1.8b
+; CHECK-NOIMM8-NEXT:    ret
   %u.wide = zext <8 x i8> %u to <8 x i32>
   %s.wide = zext <8 x i8> %s to <8 x i32>
   %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
@@ -68,6 +79,11 @@ define <4 x i32> @sdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
 ; CHECK-NODOT-NEXT:    saddw2 v0.4s, v0.4s, v1.8h
 ; CHECK-NODOT-NEXT:    add v0.4s, v2.4s, v0.4s
 ; CHECK-NODOT-NEXT:    ret
+;
+; CHECK-NOIMM8-LABEL: sdot:
+; CHECK-NOIMM8:       // %bb.0:
+; CHECK-NOIMM8-NEXT:    sdot v0.4s, v2.16b, v1.16b
+; CHECK-NOIMM8-NEXT:    ret
   %u.wide = sext <16 x i8> %u to <16 x i32>
   %s.wide = sext <16 x i8> %s to <16 x i32>
   %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
@@ -95,6 +111,11 @@ define <2 x i32> @sdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
 ; CHECK-NODOT-NEXT:    saddw v1.4s, v2.4s, v4.4h
 ; CHECK-NODOT-NEXT:    add v0.2s, v1.2s, v0.2s
 ; CHECK-NODOT-NEXT:    ret
+;
+; CHECK-NOIMM8-LABEL: sdot_narrow:
+; CHECK-NOIMM8:       // %bb.0:
+; CHECK-NOIMM8-NEXT:    sdot v0.2s, v2.8b, v1.8b
+; CHECK-NOIMM8-NEXT:    ret
   %u.wide = sext <8 x i8> %u to <8 x i32>
   %s.wide = sext <8 x i8> %s to <8 x i32>
   %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
@@ -102,7 +123,175 @@ define <2 x i32> @sdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
   ret <2 x i32> %partial.reduce
 }
 
-define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
+define <4 x i32> @usdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
+; CHECK-DOT-LABEL: usdot:
+; CHECK-DOT:       // %bb.0:
+; CHECK-DOT-NEXT:    usdot v0.4s, v1.16b, v2.16b
+; CHECK-DOT-NEXT:    ret
+;
+; CHECK-NODOT-LABEL: usdot:
+; CHECK-NODOT:       // %bb.0:
+; CHECK-NODOT-NEXT:    ushll v3.8h, v1.8b, #0
+; CHECK-NODOT-NEXT:    ushll2 v1.8h, v1.16b, #0
+; CHECK-NODOT-NEXT:    sshll v4.8h, v2.8b, #0
+; CHECK-NODOT-NEXT:    sshll2 v2.8h, v2.16b, #0
+; CHECK-NODOT-NEXT:    smlal v0.4s, v4.4h, v3.4h
+; CHECK-NODOT-NEXT:    smull v5.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT:    smlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NODOT-NEXT:    smlal2 v5.4s, v4.8h, v3.8h
+; CHECK-NODOT-NEXT:    add v0.4s, v5.4s, v0.4s
+; CHECK-NODOT-NEXT:    ret
+;
+; CHECK-NOIMM8-LABEL: usdot:
+; CHECK-NOIMM8:       // %bb.0:
+; CHECK-NOIMM8-NEXT:    ushll v3.8h, v1.8b, #0
+; CHECK-NOIMM8-NEXT:    ushll2 v1.8h, v1.16b, #0
+; CHECK-NOIMM8-NEXT:    sshll v4.8h, v2.8b, #0
+; CHECK-NOIMM8-NEXT:    sshll2 v2.8h, v2.16b, #0
+; CHECK-NOIMM8-NEXT:    smlal v0.4s, v4.4h, v3.4h
+; CHECK-NOIMM8-NEXT:    smull v5.4s, v2.4h, v1.4h
+; CHECK-NOIMM8-NEXT:    smlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NOIMM8-NEXT:    smlal2 v5.4s, v4.8h, v3.8h
+; CHECK-NOIMM8-NEXT:    add v0.4s, v5.4s, v0.4s
+; CHECK-NOIMM8-NEXT:    ret
+  %u.wide = zext <16 x i8> %u to <16 x i32>
+  %s.wide = sext <16 x i8> %s to <16 x i32>
+  %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mult)
+  ret <4 x i32> %partial.reduce
+}
+
+define <2 x i32> @usdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
+; CHECK-DOT-LABEL: usdot_narrow:
+; CHECK-DOT:       // %bb.0:
+; CHECK-DOT-NEXT:    usdot v0.2s, v1.8b, v2.8b
+; CHECK-DOT-NEXT:    ret
+;
+; CHECK-NODOT-LABEL: usdot_narrow:
+; CHECK-NODOT:       // %bb.0:
+; CHECK-NODOT-NEXT:    ushll v1.8h, v1.8b, #0
+; CHECK-NODOT-NEXT:    sshll v2.8h, v2.8b, #0
+; CHECK-NODOT-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NODOT-NEXT:    smull v3.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT:    smull2 v4.4s, v2.8h, v1.8h
+; CHECK-NODOT-NEXT:    ext v5.16b, v1.16b, v1.16b, #8
+; CHECK-NODOT-NEXT:    ext v6.16b, v2.16b, v2.16b, #8
+; CHECK-NODOT-NEXT:    smlal v0.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NODOT-NEXT:    ext v1.16b, v4.16b, v4.16b, #8
+; CHECK-NODOT-NEXT:    smlal v3.4s, v6.4h, v5.4h
+; CHECK-NODOT-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NODOT-NEXT:    add v0.2s, v3.2s, v0.2s
+; CHECK-NODOT-NEXT:    ret
+;
+; CHECK-NOIMM8-LABEL: usdot_narrow:
+; CHECK-NOIMM8:       // %bb.0:
+; CHECK-NOIMM8-NEXT:    ushll v1.8h, v1.8b, #0
+; CHECK-NOIMM8-NEXT:    sshll v2.8h, v2.8b, #0
+; CHECK-NOIMM8-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NOIMM8-NEXT:    smull v3.4s, v2.4h, v1.4h
+; CHECK-NOIMM8-NEXT:    smull2 v4.4s, v2.8h, v1.8h
+; CHECK-NOIMM8-NEXT:    ext v5.16b, v1.16b, v1.16b, #8
+; CHECK-NOIMM8-NEXT:    ext v6.16b, v2.16b, v2.16b, #8
+; CHECK-NOIMM8-NEXT:    smlal v0.4s, v2.4h, v1.4h
+; CHECK-NOIMM8-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NOIMM8-NEXT:    ext v1.16b, v4.16b, v4.16b, #8
+; CHECK-NOIMM8-NEXT:    smlal v3.4s, v6.4h, v5.4h
+; CHECK-NOIMM8-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NOIMM8-NEXT:    add v0.2s, v3.2s, v0.2s
+; CHECK-NOIMM8-NEXT:    ret
+  %u.wide = zext <8 x i8> %u to <8 x i32>
+  %s.wide = sext <8 x i8> %s to <8 x i32>
+  %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<2 x i32> %acc, <8 x i32> %mult)
+  ret <2 x i32> %partial.reduce
+}
+
+define <4 x i32> @sudot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) #0{
+; CHECK-DOT-LABEL: sudot:
+; CHECK-DOT:       // %bb.0:
+; CHECK-DOT-NEXT:    usdot v0.4s, v2.16b, v1.16b
+; CHECK-DOT-NEXT:    ret
+;
+; CHECK-NODOT-LABEL: sudot:
+; CHECK-NODOT:       // %bb.0:
+; CHECK-NODOT-NEXT:    sshll v3.8h, v1.8b, #0
+; CHECK-NODOT-NEXT:    sshll2 v1.8h, v1.16b, #0
+; CHECK-NODOT-NEXT:    ushll v4.8h, v2.8b, #0
+; CHECK-NODOT-NEXT:    ushll2 v2.8h, v2.16b, #0
+; CHECK-NODOT-NEXT:    smlal v0.4s, v4.4h, v3.4h
+; CHECK-NODOT-NEXT:    smull v5.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT:    smlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NODOT-NEXT:    smlal2 v5.4s, v4.8h, v3.8h
+; CHECK-NODOT-NEXT:    add v0.4s, v5.4s, v0.4s
+; CHECK-NODOT-NEXT:    ret
+;
+; CHECK-NOIMM8-LABEL: sudot:
+; CHECK-NOIMM8:       // %bb.0:
+; CHECK-NOIMM8-NEXT:    sshll v3.8h, v1.8b, #0
+; CHECK-NOIMM8-NEXT:    sshll2 v1.8h, v1.16b, #0
+; CHECK-NOIMM8-NEXT:    ushll v4.8h, v2.8b, #0
+; CHECK-NOIMM8-NEXT:    ushll2 v2.8h, v2.16b, #0
+; CHECK-NOIMM8-NEXT:    smlal v0.4s, v4.4h, v3.4h
+; CHECK-NOIMM8-NEXT:    smull v5.4s, v2.4h, v1.4h
+; CHECK-NOIMM8-NEXT:    smlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NOIMM8-NEXT:    smlal2 v5.4s, v4.8h, v3.8h
+; CHECK-NOIMM8-NEXT:    add v0.4s, v5.4s, v0.4s
+; CHECK-NOIMM8-NEXT:    ret
+  %u.wide = sext <16 x i8> %u to <16 x i32>
+  %s.wide = zext <16 x i8> %s to <16 x i32>
+  %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mult)
+  ret <4 x i32> %partial.reduce
+}
+
+define <2 x i32> @sudot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
+; CHECK-DOT-LABEL: sudot_narrow:
+; CHECK-DOT:       // %bb.0:
+; CHECK-DOT-NEXT:    usdot v0.2s, v2.8b, v1.8b
+; CHECK-DOT-NEXT:    ret
+;
+; CHECK-NODOT-LABEL: sudot_narrow:
+; CHECK-NODOT:       // %bb.0:
+; CHECK-NODOT-NEXT:    sshll v1.8h, v1.8b, #0
+; CHECK-NODOT-NEXT:    ushll v2.8h, v2.8b, #0
+; CHECK-NODOT-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NODOT-NEXT:    smull v3.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT:    smull2 v4.4s, v2.8h, v1.8h
+; CHECK-NODOT-NEXT:    ext v5.16b, v1.16b, v1.16b, #8
+; CHECK-NODOT-NEXT:    ext v6.16b, v2.16b, v2.16b, #8
+; CHECK-NODOT-NEXT:    smlal v0.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NODOT-NEXT:    ext v1.16b, v4.16b, v4.16b, #8
+; CHECK-NODOT-NEXT:    smlal v3.4s, v6.4h, v5.4h
+; CHECK-NODOT-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NODOT-NEXT:    add v0.2s, v3.2s, v0.2s
+; CHECK-NODOT-NEXT:    ret
+;
+; CHECK-NOIMM8-LABEL: sudot_narrow:
+; CHECK-NOIMM8:       // %bb.0:
+; CHECK-NOIMM8-NEXT:    sshll v1.8h, v1.8b, #0
+; CHECK-NOIMM8-NEXT:    ushll v2.8h, v2.8b, #0
+; CHECK-NOIMM8-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NOIMM8-NEXT:    smull v3.4s, v2.4h, v1.4h
+; CHECK-NOIMM8-NEXT:    smull2 v4.4s, v2.8h, v1.8h
+; CHECK-NOIMM8-NEXT:    ext v5.16b, v1.16b, v1.16b, #8
+; CHECK-NOIMM8-NEXT:    ext v6.16b, v2.16b, v2.16b, #8
+; CHECK-NOIMM8-NEXT:    smlal v0.4s, v2.4h, v1.4h
+; CHECK-NOIMM8-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NOIMM8-NEXT:    ext v1.16b, v4.16b, v4.16b, #8
+; CHECK-NOIMM8-NEXT:    smlal v3.4s, v6.4h, v5.4h
+; CHECK-NOIMM8-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NOIMM8-NEXT:    add v0.2s, v3.2s, v0.2s
+; CHECK-NOIMM8-NEXT:    ret
+  %u.wide = sext <8 x i8> %u to <8 x i32>
+  %s.wide = zext <8 x i8> %s to <8 x i32>
+  %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<2 x i32> %acc, <8 x i32> %mult)
+  ret <2 x i32> %partial.reduce
+}
+
+define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
 ; CHECK-LABEL: not_udot:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    umull v1.8h, v2.8b, v1.8b
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index b1354ab210f727..35d2b8ca30a041 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -1,8 +1,9 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s
+; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NOIMM8
 
-define <vscale x 4 x i32> @dotp(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
-; CHECK-LABEL: dotp:
+define <vscale x 4 x i32> @udot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-LABEL: udot:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    udot z0.s, z1.b, z2.b
 ; CHECK-NEXT:    ret
@@ -14,8 +15,8 @@ entry:
   ret <vscale x 4 x i32> %partial.reduce
 }
 
-define <vscale x 2 x i64> @dotp_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
-; CHECK-LABEL: dotp_wide:
+define <vscale x 2 x i64> @udot_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: udot_wide:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    udot z0.d, z1.h, z2.h
 ; CHECK-NEXT:    ret
@@ -27,8 +28,8 @@ entry:
   ret <vscale x 2 x i64> %partial.reduce
 }
 
-define <vscale x 4 x i32> @dotp_sext(<vscale x 4 x i32> %accc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
-; CHECK-LABEL: dotp_sext:
+define <vscale x 4 x i32> @sdot(<vscale x 4 x i32> %accc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-LABEL: sdot:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    sdot z0.s, z1.b, z2.b
 ; CHECK-NEXT:    ret
@@ -40,8 +41,8 @@ entry:
   ret <vscale x 4 x i32> %partial.reduce
 }
 
-define <vscale x 2 x i64> @dotp_wide_sext(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
-; CHECK-LABEL: dotp_wide_sext:
+define <vscale x 2 x i64> @sdot_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: sdot_wide:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    sdot z0.d, z1.h, z2.h
 ; CHECK-NEXT:    ret
@@ -53,8 +54,80 @@ entry:
   ret <vscale x 2 x i64> %partial.reduce
 }
 
-define <vscale x 4 x i32> @not_dotp(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
-; CHECK-LABEL: not_dotp:
+define <vscale x 4 x i32> @usdot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-DOT-LABEL: usdot:
+; CHECK-DOT:       // %bb.0: // %entry
+; CHECK-DOT-NEXT:    usdot z0.s, z1.b, z2.b
+; CHECK-DOT-NEXT:    ret
+;
+; CHECK-NOIMM8-LABEL: usdot:
+; CHECK-NOIMM8:       // %bb.0: // %entry
+; CHECK-NOIMM8-NEXT:    uunpklo z3.h, z1.b
+; CHECK-NOIMM8-NEXT:    sunpklo z4.h, z2.b
+; CHECK-NOIMM8-NEXT:    uunpkhi z1.h, z1.b
+; CHECK-NOIMM8-NEXT:    sunpkhi z2.h, z2.b
+; CHECK-NOIMM8-NEXT:    ptrue p0.s
+; CHECK-NOIMM8-NEXT:    uunpklo z5.s, z3.h
+; CHECK-NOIMM8-NEXT:    uunpkhi z3.s, z3.h
+; CHECK-NOIMM8-NEXT:    sunpklo z6.s, z4.h
+; CHECK-NOIMM8-NEXT:    sunpkhi z4.s, z4.h
+; CHECK-NOIMM8-NEXT:    uunpklo z7.s, z1.h
+; CHECK-NOIMM8-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NOIMM8-NEXT:    sunpklo z24.s, z2.h
+; CHECK-NOIMM8-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NOIMM8-NEXT:    mla z0.s, p0/m, z5.s, z6.s
+; CHECK-NOIMM8-NEXT:    mul z3.s, z3.s, z4.s
+; CHECK-NOIMM8-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NOIMM8-NEXT:    movprfx z1, z3
+; CHECK-NOIMM8-NEXT:    mla z1.s, p0/m, z7.s, z24.s
+; CHECK-NOIMM8-NEXT:    add z0.s, z1.s, z0.s
+; CHECK-NOIMM8-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @sudot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-DOT-LABEL: sudot:
+; CHECK-DOT:       // %bb.0: // %entry
+; CHECK-DOT-NEXT:    usdot z0.s, z2.b, z1.b
+; CHECK-DOT-NEXT:    ret
+;
+; CHECK-NOIMM8-LABEL: sudot:
+; CHECK-NOIMM8:       // %bb.0: // %entry
+; CHECK-NOIMM8-NEXT:    sunpklo z3.h, z1.b
+; CHECK-NOIMM8-NEXT:    uunpklo z4.h, z2.b
+; CHECK-NOIMM8-NEXT:    sunpkhi z1.h, z1.b
+; CHECK-NOIMM8-NEXT:    uunpkhi z2.h, z2.b
+; CHECK-NOIMM8-NEXT:    ptrue p0.s
+; CHECK-NOIMM8-NEXT:    sunpklo z5.s, z3.h
+; CHECK-NOIMM8-NEXT:    sunpkhi z3.s, z3.h
+; CHECK-NOIMM8-NEXT:    uunpklo z6.s, z4.h
+; CHECK-NOIMM8-NEXT:    uunpkhi z4.s, z4.h
+; CHECK-NOIMM8-NEXT:    sunpklo z7.s, z1.h
+; CHECK-NOIMM8-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NOIMM8-NEXT:    uunpklo z24.s, z2.h
+; CHECK-NOIMM8-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NOIMM8-NEXT:    mla z0.s, p0/m, z5.s, z6.s
+; CHECK-NOIMM8-NEXT:    mul z3.s, z3.s, z4.s
+; CHECK-NOIMM8-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NOIMM8-NEXT:    movprfx z1, z3
+; CHECK-NOIMM8-NEXT:    mla z1.s, p0/m, z7.s, z24.s
+; CHECK-NOIMM8-NEXT:    add z0.s, z1.s, z0.s
+; CHECK-NOIMM8-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @not_udot(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
+; CHECK-LABEL: not_udot:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    and z1.h, z1.h, #0xff
 ; CHECK-NEXT:    and z2.h, z2.h, #0xff
@@ -74,8 +147,8 @@ entry:
   ret <vscale x 4 x i32> %partial.reduce
 }
 
-define <vscale x 2 x i64> @not_dotp_wide(<vscale x 2 x i64> %acc, <vscale x 4 x i16> %a, <vscale x 4 x i16> %b) {
-; CHECK-LABEL: not_dotp_wide:
+define <vscale x 2 x i64> @not_udot_wide(<vscale x 2 x i64> %acc, <vscale x 4 x i16> %a, <vscale x 4 x i16> %b) {
+; CHECK-LABEL: not_udot_wide:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    and z1.s, z1.s, #0xffff
 ; CHECK-NEXT:    and z2.s, z2.s, #0xffff
@@ -94,3 +167,65 @@ entry:
   %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %mult)
   ret <vscale x 2 x i64> %partial.reduce
 }
+
+define <vscale x 2 x i64> @not_usdot(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: not_usdot:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    uunpklo z3.s, z1.h
+; CHECK-NEXT:    sunpklo z4.s, z2.h
+; CHECK-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    uunpklo z5.d, z3.s
+; CHECK-NEXT:    uunpkhi z3.d, z3.s
+; CHECK-NEXT:    sunpklo z6.d, z4.s
+; CHECK-NEXT:    sunpkhi z4.d, z4.s
+; CHECK-NEXT:    uunpklo z7.d, z1.s
+; CHECK-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NEXT:    sunpklo z24.d, z2.s
+; CHECK-NEXT:    sunpkhi z2.d, z2.s
+; CHECK-NEXT:    mla z0.d, p0/m, z5.d, z6.d
+; CHECK-NEXT:    mul z3.d, z3.d, z4.d
+; CHECK-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    movprfx z1, z3
+; CHECK-NEXT:    mla z1.d, p0/m, z7.d, z24.d
+; CHECK-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+  %b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i64>
+  %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
+  ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 2 x i64> @not_sudot(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: not_sudot:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sunpklo z3.s, z1.h
+; CHECK-NEXT:    uunpklo z4.s, z2.h
+; CHECK-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    sunpklo z5.d, z3.s
+; CHECK-NEXT:    sunpkhi z3.d, z3.s
+; CHECK-NEXT:    uunpklo z6.d, z4.s
+; CHECK-NEXT:    uunpkhi z4.d, z4.s
+; CHECK-NEXT:    sunpklo z7.d, z1.s
+; CHECK-NEXT:    sunpkhi z1.d, z1.s
+; CHECK-NEXT:    uunpklo z24.d, z2.s
+; CHECK-NEXT:    uunpkhi z2.d, z2.s
+; CHECK-NEXT:    mla z0.d, p0/m, z5.d, z6.d
+; CHECK-NEXT:    mul z3.d, z3.d, z4.d
+; CHECK-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    movprfx z1, z3
+; CHECK-NEXT:    mla z1.d, p0/m, z7.d, z24.d
+; CHECK-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+  %b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i64>
+  %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
+  ret <vscale x 2 x i64> %partial.reduce
+}

>From b118e92f26d9dbe20834d280e47a248353ae0bcb Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Tue, 10 Sep 2024 15:02:50 +0100
Subject: [PATCH 2/8] Move ext opcode checks back

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a3b372e677f98a..c2899aa6aeac9c 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21824,6 +21824,14 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
 
   auto ExtA = MulOp->getOperand(0);
   auto ExtB = MulOp->getOperand(1);
+
+  bool AIsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
+  bool AIsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
+  bool BIsSExt = ExtB->getOpcode() == ISD::SIGN_EXTEND;
+  bool BIsZExt = ExtB->getOpcode() == ISD::ZERO_EXTEND;
+  if (!(AIsSExt || AIsZExt) || !(BIsSExt || BIsZExt))
+    return SDValue();
+
   auto A = ExtA->getOperand(0);
   auto B = ExtB->getOperand(0);
   if (A.getValueType() != B.getValueType())
@@ -21840,13 +21848,6 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
       !(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
     return SDValue();
 
-  bool AIsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
-  bool AIsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
-  bool BIsSExt = ExtB->getOpcode() == ISD::SIGN_EXTEND;
-  bool BIsZExt = ExtB->getOpcode() == ISD::ZERO_EXTEND;
-  if (!(AIsSExt || AIsZExt) || !(BIsSExt || BIsZExt))
-    return SDValue();
-
   // If the extensions are mixed, we should lower it to a usdot instead
   if (AIsZExt != BIsZExt) {
     if (!Subtarget->hasMatMulInt8())

>From eb62a5d0c7b99f1de289256cfad5bdd8ef32769e Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Wed, 11 Sep 2024 16:43:42 +0100
Subject: [PATCH 3/8] Add USDOT ISD node

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 15 +++++----------
 llvm/lib/Target/AArch64/AArch64ISelLowering.h   |  3 ++-
 llvm/lib/Target/AArch64/AArch64InstrInfo.td     |  4 ++++
 llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td  |  1 +
 4 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index c2899aa6aeac9c..696883bff5fd7d 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2701,6 +2701,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::SADDLV)
     MAKE_CASE(AArch64ISD::SDOT)
     MAKE_CASE(AArch64ISD::UDOT)
+    MAKE_CASE(AArch64ISD::USDOT)
     MAKE_CASE(AArch64ISD::SMINV)
     MAKE_CASE(AArch64ISD::UMINV)
     MAKE_CASE(AArch64ISD::SMAXV)
@@ -21849,28 +21850,22 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
     return SDValue();
 
   // If the extensions are mixed, we should lower it to a usdot instead
+  unsigned Opcode = 0;
   if (AIsZExt != BIsZExt) {
     if (!Subtarget->hasMatMulInt8())
       return SDValue();
-    bool Scalable = N->getValueType(0).isScalableVT();
 
+    bool Scalable = N->getValueType(0).isScalableVT();
     // There's no nxv2i64 version of usdot
     if (Scalable && ReducedType != MVT::nxv4i32)
       return SDValue();
 
-    unsigned IntrinsicID =
-        Scalable ? Intrinsic::aarch64_sve_usdot : Intrinsic::aarch64_neon_usdot;
+    Opcode = AArch64ISD::USDOT;
     // USDOT expects the first operand to be unsigned, so swap the operands if
     // the first is signed and the second is unsigned
     if (AIsSExt && BIsZExt)
       std::swap(A, B);
-    return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, ReducedType,
-                       DAG.getConstant(IntrinsicID, DL, MVT::i64), NarrowOp, A,
-                       B);
-  }
-
-  unsigned Opcode = 0;
-  if (AIsSExt)
+  } else if (AIsSExt)
     Opcode = AArch64ISD::SDOT;
   else if (AIsZExt)
     Opcode = AArch64ISD::UDOT;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index f9d45b02d30e30..e79b41b66d77ed 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -280,9 +280,10 @@ enum NodeType : unsigned {
   SADDLP,
   UADDLP,
 
-  // udot/sdot instructions
+  // udot/sdot/usdot instructions
   UDOT,
   SDOT,
+  USDOT,
 
   // Vector across-lanes min/max
   // Only the lower result lane is defined.
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index ccef85bfaa8afc..76dc39d309fe99 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -855,6 +855,7 @@ def AArch64frsqrts  : SDNode<"AArch64ISD::FRSQRTS", SDTFPBinOp>;
 
 def AArch64sdot     : SDNode<"AArch64ISD::SDOT", SDT_AArch64Dot>;
 def AArch64udot     : SDNode<"AArch64ISD::UDOT", SDT_AArch64Dot>;
+def AArch64usdot    : SDNode<"AArch64ISD::USDOT", SDT_AArch64Dot>;
 
 def AArch64saddv    : SDNode<"AArch64ISD::SADDV", SDT_AArch64UnaryVec>;
 def AArch64uaddv    : SDNode<"AArch64ISD::UADDV", SDT_AArch64UnaryVec>;
@@ -1420,6 +1421,9 @@ def USMMLA : SIMDThreeSameVectorMatMul<1, 0, "usmmla", int_aarch64_neon_usmmla>;
 defm USDOT : SIMDThreeSameVectorDot<0, 1, "usdot", int_aarch64_neon_usdot>;
 defm USDOTlane : SIMDThreeSameVectorDotIndex<0, 1, 0b10, "usdot", int_aarch64_neon_usdot>;
 
+def : Pat<(v4i32 (AArch64usdot (v4i32 V128:$Rd), (v16i8 V128:$Rm), (v16i8 V128:$Rn))), (USDOTv16i8 $Rd, $Rm, $Rn)>;
+def : Pat<(v2i32 (AArch64usdot (v2i32 V64:$Rd), (v8i8 V64:$Rm), (v8i8 V64:$Rn))), (USDOTv8i8 $Rd, $Rm, $Rn)>;
+
 // sudot lane has a pattern where usdot is expected (there is no sudot).
 // The second operand is used in the dup operation to repeat the indexed
 // element.
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 692cd66d38437d..47db580759d293 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -3408,6 +3408,7 @@ let Predicates = [HasSVEorSME, HasMatMulInt8] in {
   defm USDOT_ZZZ  : sve_int_dot_mixed<"usdot", int_aarch64_sve_usdot>;
   defm USDOT_ZZZI : sve_int_dot_mixed_indexed<0, "usdot", int_aarch64_sve_usdot_lane>;
   defm SUDOT_ZZZI : sve_int_dot_mixed_indexed<1, "sudot", int_aarch64_sve_sudot_lane>;
+  def : Pat<(nxv4i32 (AArch64usdot (nxv4i32 ZPR32:$Rd), (nxv16i8 ZPR8:$Rm), (nxv16i8 ZPR8:$Rn))), (USDOT_ZZZ $Rd, $Rm, $Rn)>;
 } // End HasSVEorSME, HasMatMulInt8
 
 let Predicates = [HasSVE, HasMatMulFP32] in {

>From 15117fe644edae1423372fcb3040e24ea94b0a4d Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Thu, 12 Sep 2024 11:00:46 +0100
Subject: [PATCH 4/8] Match in ISelLowering

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 5 +++++
 llvm/lib/Target/AArch64/AArch64InstrInfo.td     | 5 +----
 llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td  | 3 +--
 3 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 696883bff5fd7d..b340b21dd28b91 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -6115,6 +6115,11 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
     return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
                        Op.getOperand(2), Op.getOperand(3));
   }
+  case Intrinsic::aarch64_neon_usdot:
+  case Intrinsic::aarch64_sve_usdot: {
+    return DAG.getNode(AArch64ISD::USDOT, dl, Op.getValueType(),
+                       Op.getOperand(1), Op.getOperand(2), Op.getOperand(3));
+  }
   case Intrinsic::get_active_lane_mask: {
     SDValue ID =
         DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, dl, MVT::i64);
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 76dc39d309fe99..f8569e7f65c30a 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -1418,12 +1418,9 @@ let Predicates = [HasMatMulInt8] in {
 def  SMMLA : SIMDThreeSameVectorMatMul<0, 0, "smmla", int_aarch64_neon_smmla>;
 def  UMMLA : SIMDThreeSameVectorMatMul<0, 1, "ummla", int_aarch64_neon_ummla>;
 def USMMLA : SIMDThreeSameVectorMatMul<1, 0, "usmmla", int_aarch64_neon_usmmla>;
-defm USDOT : SIMDThreeSameVectorDot<0, 1, "usdot", int_aarch64_neon_usdot>;
+defm USDOT : SIMDThreeSameVectorDot<0, 1, "usdot", AArch64usdot>;
 defm USDOTlane : SIMDThreeSameVectorDotIndex<0, 1, 0b10, "usdot", int_aarch64_neon_usdot>;
 
-def : Pat<(v4i32 (AArch64usdot (v4i32 V128:$Rd), (v16i8 V128:$Rm), (v16i8 V128:$Rn))), (USDOTv16i8 $Rd, $Rm, $Rn)>;
-def : Pat<(v2i32 (AArch64usdot (v2i32 V64:$Rd), (v8i8 V64:$Rm), (v8i8 V64:$Rn))), (USDOTv8i8 $Rd, $Rm, $Rn)>;
-
 // sudot lane has a pattern where usdot is expected (there is no sudot).
 // The second operand is used in the dup operation to repeat the indexed
 // element.
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 47db580759d293..c4207dd478594f 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -3405,10 +3405,9 @@ let Predicates = [HasSVE, HasMatMulInt8] in {
 } // End HasSVE, HasMatMulInt8
 
 let Predicates = [HasSVEorSME, HasMatMulInt8] in {
-  defm USDOT_ZZZ  : sve_int_dot_mixed<"usdot", int_aarch64_sve_usdot>;
+  defm USDOT_ZZZ  : sve_int_dot_mixed<"usdot", AArch64usdot>;
   defm USDOT_ZZZI : sve_int_dot_mixed_indexed<0, "usdot", int_aarch64_sve_usdot_lane>;
   defm SUDOT_ZZZI : sve_int_dot_mixed_indexed<1, "sudot", int_aarch64_sve_sudot_lane>;
-  def : Pat<(nxv4i32 (AArch64usdot (nxv4i32 ZPR32:$Rd), (nxv16i8 ZPR8:$Rm), (nxv16i8 ZPR8:$Rn))), (USDOT_ZZZ $Rd, $Rm, $Rn)>;
 } // End HasSVEorSME, HasMatMulInt8
 
 let Predicates = [HasSVE, HasMatMulFP32] in {

>From 7268fe171d0f7b46f7a05a7aefc7d338fee64e9e Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 16 Sep 2024 16:41:32 +0100
Subject: [PATCH 5/8] Improve tests

---
 .../neon-partial-reduce-dot-product.ll        | 160 +++++++-----------
 .../AArch64/sve-partial-reduce-dot-product.ll | 108 ++++++------
 2 files changed, 115 insertions(+), 153 deletions(-)

diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index 7b6c01f4691175..fead3689f6451b 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -1,7 +1,7 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT
-; RUN: llc -mtriple aarch64 -mattr=+neon,+i8mm < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NODOT
-; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOIMM8
+; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-I8MM
+; RUN: llc -mtriple aarch64 -mattr=+neon,+i8mm < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NODOT,CHECK-I8MM
+; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-NOI8MM
 
 define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
 ; CHECK-DOT-LABEL: udot:
@@ -19,11 +19,6 @@ define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
 ; CHECK-NODOT-NEXT:    uaddw2 v0.4s, v0.4s, v1.8h
 ; CHECK-NODOT-NEXT:    add v0.4s, v2.4s, v0.4s
 ; CHECK-NODOT-NEXT:    ret
-;
-; CHECK-NOIMM8-LABEL: udot:
-; CHECK-NOIMM8:       // %bb.0:
-; CHECK-NOIMM8-NEXT:    udot v0.4s, v2.16b, v1.16b
-; CHECK-NOIMM8-NEXT:    ret
   %u.wide = zext <16 x i8> %u to <16 x i32>
   %s.wide = zext <16 x i8> %s to <16 x i32>
   %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
@@ -51,11 +46,6 @@ define <2 x i32> @udot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
 ; CHECK-NODOT-NEXT:    uaddw v1.4s, v2.4s, v4.4h
 ; CHECK-NODOT-NEXT:    add v0.2s, v1.2s, v0.2s
 ; CHECK-NODOT-NEXT:    ret
-;
-; CHECK-NOIMM8-LABEL: udot_narrow:
-; CHECK-NOIMM8:       // %bb.0:
-; CHECK-NOIMM8-NEXT:    udot v0.2s, v2.8b, v1.8b
-; CHECK-NOIMM8-NEXT:    ret
   %u.wide = zext <8 x i8> %u to <8 x i32>
   %s.wide = zext <8 x i8> %s to <8 x i32>
   %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
@@ -79,11 +69,6 @@ define <4 x i32> @sdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
 ; CHECK-NODOT-NEXT:    saddw2 v0.4s, v0.4s, v1.8h
 ; CHECK-NODOT-NEXT:    add v0.4s, v2.4s, v0.4s
 ; CHECK-NODOT-NEXT:    ret
-;
-; CHECK-NOIMM8-LABEL: sdot:
-; CHECK-NOIMM8:       // %bb.0:
-; CHECK-NOIMM8-NEXT:    sdot v0.4s, v2.16b, v1.16b
-; CHECK-NOIMM8-NEXT:    ret
   %u.wide = sext <16 x i8> %u to <16 x i32>
   %s.wide = sext <16 x i8> %s to <16 x i32>
   %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
@@ -111,11 +96,6 @@ define <2 x i32> @sdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
 ; CHECK-NODOT-NEXT:    saddw v1.4s, v2.4s, v4.4h
 ; CHECK-NODOT-NEXT:    add v0.2s, v1.2s, v0.2s
 ; CHECK-NODOT-NEXT:    ret
-;
-; CHECK-NOIMM8-LABEL: sdot_narrow:
-; CHECK-NOIMM8:       // %bb.0:
-; CHECK-NOIMM8-NEXT:    sdot v0.2s, v2.8b, v1.8b
-; CHECK-NOIMM8-NEXT:    ret
   %u.wide = sext <8 x i8> %u to <8 x i32>
   %s.wide = sext <8 x i8> %s to <8 x i32>
   %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
@@ -124,11 +104,6 @@ define <2 x i32> @sdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
 }
 
 define <4 x i32> @usdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
-; CHECK-DOT-LABEL: usdot:
-; CHECK-DOT:       // %bb.0:
-; CHECK-DOT-NEXT:    usdot v0.4s, v1.16b, v2.16b
-; CHECK-DOT-NEXT:    ret
-;
 ; CHECK-NODOT-LABEL: usdot:
 ; CHECK-NODOT:       // %bb.0:
 ; CHECK-NODOT-NEXT:    ushll v3.8h, v1.8b, #0
@@ -142,18 +117,18 @@ define <4 x i32> @usdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
 ; CHECK-NODOT-NEXT:    add v0.4s, v5.4s, v0.4s
 ; CHECK-NODOT-NEXT:    ret
 ;
-; CHECK-NOIMM8-LABEL: usdot:
-; CHECK-NOIMM8:       // %bb.0:
-; CHECK-NOIMM8-NEXT:    ushll v3.8h, v1.8b, #0
-; CHECK-NOIMM8-NEXT:    ushll2 v1.8h, v1.16b, #0
-; CHECK-NOIMM8-NEXT:    sshll v4.8h, v2.8b, #0
-; CHECK-NOIMM8-NEXT:    sshll2 v2.8h, v2.16b, #0
-; CHECK-NOIMM8-NEXT:    smlal v0.4s, v4.4h, v3.4h
-; CHECK-NOIMM8-NEXT:    smull v5.4s, v2.4h, v1.4h
-; CHECK-NOIMM8-NEXT:    smlal2 v0.4s, v2.8h, v1.8h
-; CHECK-NOIMM8-NEXT:    smlal2 v5.4s, v4.8h, v3.8h
-; CHECK-NOIMM8-NEXT:    add v0.4s, v5.4s, v0.4s
-; CHECK-NOIMM8-NEXT:    ret
+; CHECK-NOI8MM-LABEL: usdot:
+; CHECK-NOI8MM:       // %bb.0:
+; CHECK-NOI8MM-NEXT:    ushll v3.8h, v1.8b, #0
+; CHECK-NOI8MM-NEXT:    ushll2 v1.8h, v1.16b, #0
+; CHECK-NOI8MM-NEXT:    sshll v4.8h, v2.8b, #0
+; CHECK-NOI8MM-NEXT:    sshll2 v2.8h, v2.16b, #0
+; CHECK-NOI8MM-NEXT:    smlal v0.4s, v4.4h, v3.4h
+; CHECK-NOI8MM-NEXT:    smull v5.4s, v2.4h, v1.4h
+; CHECK-NOI8MM-NEXT:    smlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NOI8MM-NEXT:    smlal2 v5.4s, v4.8h, v3.8h
+; CHECK-NOI8MM-NEXT:    add v0.4s, v5.4s, v0.4s
+; CHECK-NOI8MM-NEXT:    ret
   %u.wide = zext <16 x i8> %u to <16 x i32>
   %s.wide = sext <16 x i8> %s to <16 x i32>
   %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
@@ -162,11 +137,6 @@ define <4 x i32> @usdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
 }
 
 define <2 x i32> @usdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
-; CHECK-DOT-LABEL: usdot_narrow:
-; CHECK-DOT:       // %bb.0:
-; CHECK-DOT-NEXT:    usdot v0.2s, v1.8b, v2.8b
-; CHECK-DOT-NEXT:    ret
-;
 ; CHECK-NODOT-LABEL: usdot_narrow:
 ; CHECK-NODOT:       // %bb.0:
 ; CHECK-NODOT-NEXT:    ushll v1.8h, v1.8b, #0
@@ -184,22 +154,22 @@ define <2 x i32> @usdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
 ; CHECK-NODOT-NEXT:    add v0.2s, v3.2s, v0.2s
 ; CHECK-NODOT-NEXT:    ret
 ;
-; CHECK-NOIMM8-LABEL: usdot_narrow:
-; CHECK-NOIMM8:       // %bb.0:
-; CHECK-NOIMM8-NEXT:    ushll v1.8h, v1.8b, #0
-; CHECK-NOIMM8-NEXT:    sshll v2.8h, v2.8b, #0
-; CHECK-NOIMM8-NEXT:    // kill: def $d0 killed $d0 def $q0
-; CHECK-NOIMM8-NEXT:    smull v3.4s, v2.4h, v1.4h
-; CHECK-NOIMM8-NEXT:    smull2 v4.4s, v2.8h, v1.8h
-; CHECK-NOIMM8-NEXT:    ext v5.16b, v1.16b, v1.16b, #8
-; CHECK-NOIMM8-NEXT:    ext v6.16b, v2.16b, v2.16b, #8
-; CHECK-NOIMM8-NEXT:    smlal v0.4s, v2.4h, v1.4h
-; CHECK-NOIMM8-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
-; CHECK-NOIMM8-NEXT:    ext v1.16b, v4.16b, v4.16b, #8
-; CHECK-NOIMM8-NEXT:    smlal v3.4s, v6.4h, v5.4h
-; CHECK-NOIMM8-NEXT:    add v0.2s, v1.2s, v0.2s
-; CHECK-NOIMM8-NEXT:    add v0.2s, v3.2s, v0.2s
-; CHECK-NOIMM8-NEXT:    ret
+; CHECK-NOI8MM-LABEL: usdot_narrow:
+; CHECK-NOI8MM:       // %bb.0:
+; CHECK-NOI8MM-NEXT:    ushll v1.8h, v1.8b, #0
+; CHECK-NOI8MM-NEXT:    sshll v2.8h, v2.8b, #0
+; CHECK-NOI8MM-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NOI8MM-NEXT:    smull v3.4s, v2.4h, v1.4h
+; CHECK-NOI8MM-NEXT:    smull2 v4.4s, v2.8h, v1.8h
+; CHECK-NOI8MM-NEXT:    ext v5.16b, v1.16b, v1.16b, #8
+; CHECK-NOI8MM-NEXT:    ext v6.16b, v2.16b, v2.16b, #8
+; CHECK-NOI8MM-NEXT:    smlal v0.4s, v2.4h, v1.4h
+; CHECK-NOI8MM-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NOI8MM-NEXT:    ext v1.16b, v4.16b, v4.16b, #8
+; CHECK-NOI8MM-NEXT:    smlal v3.4s, v6.4h, v5.4h
+; CHECK-NOI8MM-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NOI8MM-NEXT:    add v0.2s, v3.2s, v0.2s
+; CHECK-NOI8MM-NEXT:    ret
   %u.wide = zext <8 x i8> %u to <8 x i32>
   %s.wide = sext <8 x i8> %s to <8 x i32>
   %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
@@ -208,11 +178,6 @@ define <2 x i32> @usdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
 }
 
 define <4 x i32> @sudot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) #0{
-; CHECK-DOT-LABEL: sudot:
-; CHECK-DOT:       // %bb.0:
-; CHECK-DOT-NEXT:    usdot v0.4s, v2.16b, v1.16b
-; CHECK-DOT-NEXT:    ret
-;
 ; CHECK-NODOT-LABEL: sudot:
 ; CHECK-NODOT:       // %bb.0:
 ; CHECK-NODOT-NEXT:    sshll v3.8h, v1.8b, #0
@@ -226,18 +191,18 @@ define <4 x i32> @sudot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) #0{
 ; CHECK-NODOT-NEXT:    add v0.4s, v5.4s, v0.4s
 ; CHECK-NODOT-NEXT:    ret
 ;
-; CHECK-NOIMM8-LABEL: sudot:
-; CHECK-NOIMM8:       // %bb.0:
-; CHECK-NOIMM8-NEXT:    sshll v3.8h, v1.8b, #0
-; CHECK-NOIMM8-NEXT:    sshll2 v1.8h, v1.16b, #0
-; CHECK-NOIMM8-NEXT:    ushll v4.8h, v2.8b, #0
-; CHECK-NOIMM8-NEXT:    ushll2 v2.8h, v2.16b, #0
-; CHECK-NOIMM8-NEXT:    smlal v0.4s, v4.4h, v3.4h
-; CHECK-NOIMM8-NEXT:    smull v5.4s, v2.4h, v1.4h
-; CHECK-NOIMM8-NEXT:    smlal2 v0.4s, v2.8h, v1.8h
-; CHECK-NOIMM8-NEXT:    smlal2 v5.4s, v4.8h, v3.8h
-; CHECK-NOIMM8-NEXT:    add v0.4s, v5.4s, v0.4s
-; CHECK-NOIMM8-NEXT:    ret
+; CHECK-NOI8MM-LABEL: sudot:
+; CHECK-NOI8MM:       // %bb.0:
+; CHECK-NOI8MM-NEXT:    sshll v3.8h, v1.8b, #0
+; CHECK-NOI8MM-NEXT:    sshll2 v1.8h, v1.16b, #0
+; CHECK-NOI8MM-NEXT:    ushll v4.8h, v2.8b, #0
+; CHECK-NOI8MM-NEXT:    ushll2 v2.8h, v2.16b, #0
+; CHECK-NOI8MM-NEXT:    smlal v0.4s, v4.4h, v3.4h
+; CHECK-NOI8MM-NEXT:    smull v5.4s, v2.4h, v1.4h
+; CHECK-NOI8MM-NEXT:    smlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NOI8MM-NEXT:    smlal2 v5.4s, v4.8h, v3.8h
+; CHECK-NOI8MM-NEXT:    add v0.4s, v5.4s, v0.4s
+; CHECK-NOI8MM-NEXT:    ret
   %u.wide = sext <16 x i8> %u to <16 x i32>
   %s.wide = zext <16 x i8> %s to <16 x i32>
   %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
@@ -246,11 +211,6 @@ define <4 x i32> @sudot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) #0{
 }
 
 define <2 x i32> @sudot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
-; CHECK-DOT-LABEL: sudot_narrow:
-; CHECK-DOT:       // %bb.0:
-; CHECK-DOT-NEXT:    usdot v0.2s, v2.8b, v1.8b
-; CHECK-DOT-NEXT:    ret
-;
 ; CHECK-NODOT-LABEL: sudot_narrow:
 ; CHECK-NODOT:       // %bb.0:
 ; CHECK-NODOT-NEXT:    sshll v1.8h, v1.8b, #0
@@ -268,22 +228,22 @@ define <2 x i32> @sudot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
 ; CHECK-NODOT-NEXT:    add v0.2s, v3.2s, v0.2s
 ; CHECK-NODOT-NEXT:    ret
 ;
-; CHECK-NOIMM8-LABEL: sudot_narrow:
-; CHECK-NOIMM8:       // %bb.0:
-; CHECK-NOIMM8-NEXT:    sshll v1.8h, v1.8b, #0
-; CHECK-NOIMM8-NEXT:    ushll v2.8h, v2.8b, #0
-; CHECK-NOIMM8-NEXT:    // kill: def $d0 killed $d0 def $q0
-; CHECK-NOIMM8-NEXT:    smull v3.4s, v2.4h, v1.4h
-; CHECK-NOIMM8-NEXT:    smull2 v4.4s, v2.8h, v1.8h
-; CHECK-NOIMM8-NEXT:    ext v5.16b, v1.16b, v1.16b, #8
-; CHECK-NOIMM8-NEXT:    ext v6.16b, v2.16b, v2.16b, #8
-; CHECK-NOIMM8-NEXT:    smlal v0.4s, v2.4h, v1.4h
-; CHECK-NOIMM8-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
-; CHECK-NOIMM8-NEXT:    ext v1.16b, v4.16b, v4.16b, #8
-; CHECK-NOIMM8-NEXT:    smlal v3.4s, v6.4h, v5.4h
-; CHECK-NOIMM8-NEXT:    add v0.2s, v1.2s, v0.2s
-; CHECK-NOIMM8-NEXT:    add v0.2s, v3.2s, v0.2s
-; CHECK-NOIMM8-NEXT:    ret
+; CHECK-NOI8MM-LABEL: sudot_narrow:
+; CHECK-NOI8MM:       // %bb.0:
+; CHECK-NOI8MM-NEXT:    sshll v1.8h, v1.8b, #0
+; CHECK-NOI8MM-NEXT:    ushll v2.8h, v2.8b, #0
+; CHECK-NOI8MM-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NOI8MM-NEXT:    smull v3.4s, v2.4h, v1.4h
+; CHECK-NOI8MM-NEXT:    smull2 v4.4s, v2.8h, v1.8h
+; CHECK-NOI8MM-NEXT:    ext v5.16b, v1.16b, v1.16b, #8
+; CHECK-NOI8MM-NEXT:    ext v6.16b, v2.16b, v2.16b, #8
+; CHECK-NOI8MM-NEXT:    smlal v0.4s, v2.4h, v1.4h
+; CHECK-NOI8MM-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NOI8MM-NEXT:    ext v1.16b, v4.16b, v4.16b, #8
+; CHECK-NOI8MM-NEXT:    smlal v3.4s, v6.4h, v5.4h
+; CHECK-NOI8MM-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NOI8MM-NEXT:    add v0.2s, v3.2s, v0.2s
+; CHECK-NOI8MM-NEXT:    ret
   %u.wide = sext <8 x i8> %u to <8 x i32>
   %s.wide = zext <8 x i8> %s to <8 x i32>
   %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
@@ -322,3 +282,5 @@ define <2 x i32> @not_udot_narrow(<2 x i32> %acc, <4 x i8> %u, <4 x i8> %s) {
   %partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<2 x i32> %acc, <4 x i32> %mult)
   ret <2 x i32> %partial.reduce
 }
+;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
+; CHECK-I8MM: {{.*}}
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 35d2b8ca30a041..00e5ac479d02c9 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -1,6 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT
-; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NOIMM8
+; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-I8MM
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM
 
 define <vscale x 4 x i32> @udot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
 ; CHECK-LABEL: udot:
@@ -55,33 +55,33 @@ entry:
 }
 
 define <vscale x 4 x i32> @usdot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
-; CHECK-DOT-LABEL: usdot:
-; CHECK-DOT:       // %bb.0: // %entry
-; CHECK-DOT-NEXT:    usdot z0.s, z1.b, z2.b
-; CHECK-DOT-NEXT:    ret
+; CHECK-I8MM-LABEL: usdot:
+; CHECK-I8MM:       // %bb.0: // %entry
+; CHECK-I8MM-NEXT:    usdot z0.s, z1.b, z2.b
+; CHECK-I8MM-NEXT:    ret
 ;
-; CHECK-NOIMM8-LABEL: usdot:
-; CHECK-NOIMM8:       // %bb.0: // %entry
-; CHECK-NOIMM8-NEXT:    uunpklo z3.h, z1.b
-; CHECK-NOIMM8-NEXT:    sunpklo z4.h, z2.b
-; CHECK-NOIMM8-NEXT:    uunpkhi z1.h, z1.b
-; CHECK-NOIMM8-NEXT:    sunpkhi z2.h, z2.b
-; CHECK-NOIMM8-NEXT:    ptrue p0.s
-; CHECK-NOIMM8-NEXT:    uunpklo z5.s, z3.h
-; CHECK-NOIMM8-NEXT:    uunpkhi z3.s, z3.h
-; CHECK-NOIMM8-NEXT:    sunpklo z6.s, z4.h
-; CHECK-NOIMM8-NEXT:    sunpkhi z4.s, z4.h
-; CHECK-NOIMM8-NEXT:    uunpklo z7.s, z1.h
-; CHECK-NOIMM8-NEXT:    uunpkhi z1.s, z1.h
-; CHECK-NOIMM8-NEXT:    sunpklo z24.s, z2.h
-; CHECK-NOIMM8-NEXT:    sunpkhi z2.s, z2.h
-; CHECK-NOIMM8-NEXT:    mla z0.s, p0/m, z5.s, z6.s
-; CHECK-NOIMM8-NEXT:    mul z3.s, z3.s, z4.s
-; CHECK-NOIMM8-NEXT:    mla z0.s, p0/m, z1.s, z2.s
-; CHECK-NOIMM8-NEXT:    movprfx z1, z3
-; CHECK-NOIMM8-NEXT:    mla z1.s, p0/m, z7.s, z24.s
-; CHECK-NOIMM8-NEXT:    add z0.s, z1.s, z0.s
-; CHECK-NOIMM8-NEXT:    ret
+; CHECK-NOI8MM-LABEL: usdot:
+; CHECK-NOI8MM:       // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT:    uunpklo z3.h, z1.b
+; CHECK-NOI8MM-NEXT:    sunpklo z4.h, z2.b
+; CHECK-NOI8MM-NEXT:    uunpkhi z1.h, z1.b
+; CHECK-NOI8MM-NEXT:    sunpkhi z2.h, z2.b
+; CHECK-NOI8MM-NEXT:    ptrue p0.s
+; CHECK-NOI8MM-NEXT:    uunpklo z5.s, z3.h
+; CHECK-NOI8MM-NEXT:    uunpkhi z3.s, z3.h
+; CHECK-NOI8MM-NEXT:    sunpklo z6.s, z4.h
+; CHECK-NOI8MM-NEXT:    sunpkhi z4.s, z4.h
+; CHECK-NOI8MM-NEXT:    uunpklo z7.s, z1.h
+; CHECK-NOI8MM-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NOI8MM-NEXT:    sunpklo z24.s, z2.h
+; CHECK-NOI8MM-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NOI8MM-NEXT:    mla z0.s, p0/m, z5.s, z6.s
+; CHECK-NOI8MM-NEXT:    mul z3.s, z3.s, z4.s
+; CHECK-NOI8MM-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NOI8MM-NEXT:    movprfx z1, z3
+; CHECK-NOI8MM-NEXT:    mla z1.s, p0/m, z7.s, z24.s
+; CHECK-NOI8MM-NEXT:    add z0.s, z1.s, z0.s
+; CHECK-NOI8MM-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
   %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
@@ -91,33 +91,33 @@ entry:
 }
 
 define <vscale x 4 x i32> @sudot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
-; CHECK-DOT-LABEL: sudot:
-; CHECK-DOT:       // %bb.0: // %entry
-; CHECK-DOT-NEXT:    usdot z0.s, z2.b, z1.b
-; CHECK-DOT-NEXT:    ret
+; CHECK-I8MM-LABEL: sudot:
+; CHECK-I8MM:       // %bb.0: // %entry
+; CHECK-I8MM-NEXT:    usdot z0.s, z2.b, z1.b
+; CHECK-I8MM-NEXT:    ret
 ;
-; CHECK-NOIMM8-LABEL: sudot:
-; CHECK-NOIMM8:       // %bb.0: // %entry
-; CHECK-NOIMM8-NEXT:    sunpklo z3.h, z1.b
-; CHECK-NOIMM8-NEXT:    uunpklo z4.h, z2.b
-; CHECK-NOIMM8-NEXT:    sunpkhi z1.h, z1.b
-; CHECK-NOIMM8-NEXT:    uunpkhi z2.h, z2.b
-; CHECK-NOIMM8-NEXT:    ptrue p0.s
-; CHECK-NOIMM8-NEXT:    sunpklo z5.s, z3.h
-; CHECK-NOIMM8-NEXT:    sunpkhi z3.s, z3.h
-; CHECK-NOIMM8-NEXT:    uunpklo z6.s, z4.h
-; CHECK-NOIMM8-NEXT:    uunpkhi z4.s, z4.h
-; CHECK-NOIMM8-NEXT:    sunpklo z7.s, z1.h
-; CHECK-NOIMM8-NEXT:    sunpkhi z1.s, z1.h
-; CHECK-NOIMM8-NEXT:    uunpklo z24.s, z2.h
-; CHECK-NOIMM8-NEXT:    uunpkhi z2.s, z2.h
-; CHECK-NOIMM8-NEXT:    mla z0.s, p0/m, z5.s, z6.s
-; CHECK-NOIMM8-NEXT:    mul z3.s, z3.s, z4.s
-; CHECK-NOIMM8-NEXT:    mla z0.s, p0/m, z1.s, z2.s
-; CHECK-NOIMM8-NEXT:    movprfx z1, z3
-; CHECK-NOIMM8-NEXT:    mla z1.s, p0/m, z7.s, z24.s
-; CHECK-NOIMM8-NEXT:    add z0.s, z1.s, z0.s
-; CHECK-NOIMM8-NEXT:    ret
+; CHECK-NOI8MM-LABEL: sudot:
+; CHECK-NOI8MM:       // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT:    sunpklo z3.h, z1.b
+; CHECK-NOI8MM-NEXT:    uunpklo z4.h, z2.b
+; CHECK-NOI8MM-NEXT:    sunpkhi z1.h, z1.b
+; CHECK-NOI8MM-NEXT:    uunpkhi z2.h, z2.b
+; CHECK-NOI8MM-NEXT:    ptrue p0.s
+; CHECK-NOI8MM-NEXT:    sunpklo z5.s, z3.h
+; CHECK-NOI8MM-NEXT:    sunpkhi z3.s, z3.h
+; CHECK-NOI8MM-NEXT:    uunpklo z6.s, z4.h
+; CHECK-NOI8MM-NEXT:    uunpkhi z4.s, z4.h
+; CHECK-NOI8MM-NEXT:    sunpklo z7.s, z1.h
+; CHECK-NOI8MM-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NOI8MM-NEXT:    uunpklo z24.s, z2.h
+; CHECK-NOI8MM-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NOI8MM-NEXT:    mla z0.s, p0/m, z5.s, z6.s
+; CHECK-NOI8MM-NEXT:    mul z3.s, z3.s, z4.s
+; CHECK-NOI8MM-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NOI8MM-NEXT:    movprfx z1, z3
+; CHECK-NOI8MM-NEXT:    mla z1.s, p0/m, z7.s, z24.s
+; CHECK-NOI8MM-NEXT:    add z0.s, z1.s, z0.s
+; CHECK-NOI8MM-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
   %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>

>From 9dedb73cca8f2e7bb268e6c7e7e2acea3f103e8a Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 16 Sep 2024 16:48:58 +0100
Subject: [PATCH 6/8] Simplify BIsZext check

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index b340b21dd28b91..6bb73d1beb185a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21866,9 +21866,8 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
       return SDValue();
 
     Opcode = AArch64ISD::USDOT;
-    // USDOT expects the first operand to be unsigned, so swap the operands if
-    // the first is signed and the second is unsigned
-    if (AIsSExt && BIsZExt)
+    // USDOT expects the signed operand to be last
+    if (BIsZExt)
       std::swap(A, B);
   } else if (AIsSExt)
     Opcode = AArch64ISD::SDOT;

>From 15988da2bf835ea62a0b2ed779edeb69240cad98 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 16 Sep 2024 17:06:09 +0100
Subject: [PATCH 7/8] Simplify ext opcode check

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 17 +++++++----------
 1 file changed, 7 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 6bb73d1beb185a..3f79bc2a220a70 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21831,12 +21831,9 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
   auto ExtA = MulOp->getOperand(0);
   auto ExtB = MulOp->getOperand(1);
 
-  bool AIsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
-  bool AIsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
-  bool BIsSExt = ExtB->getOpcode() == ISD::SIGN_EXTEND;
-  bool BIsZExt = ExtB->getOpcode() == ISD::ZERO_EXTEND;
-  if (!(AIsSExt || AIsZExt) || !(BIsSExt || BIsZExt))
-    return SDValue();
+  if (!ISD::isExtOpcode(ExtA->getOpcode()) || !ISD::isExtOpcode(ExtB->getOpcode())) return SDValue();
+  bool AIsSigned = ExtA->getOpcode() == ISD::SIGN_EXTEND;
+  bool BIsSigned = ExtB->getOpcode() == ISD::SIGN_EXTEND;
 
   auto A = ExtA->getOperand(0);
   auto B = ExtB->getOperand(0);
@@ -21856,7 +21853,7 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
 
   // If the extensions are mixed, we should lower it to a usdot instead
   unsigned Opcode = 0;
-  if (AIsZExt != BIsZExt) {
+  if (AIsSigned != BIsSigned) {
     if (!Subtarget->hasMatMulInt8())
       return SDValue();
 
@@ -21867,11 +21864,11 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
 
     Opcode = AArch64ISD::USDOT;
     // USDOT expects the signed operand to be last
-    if (BIsZExt)
+    if (!BIsSigned)
       std::swap(A, B);
-  } else if (AIsSExt)
+  } else if (AIsSigned)
     Opcode = AArch64ISD::SDOT;
-  else if (AIsZExt)
+  else if (!AIsSigned)
     Opcode = AArch64ISD::UDOT;
 
   assert(Opcode != 0 && "Unexpected dot product case encountered.");

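For reference, the other orientation of the mixed-extend pattern, taken
from the sudot tests above; the operand-ordering logic that patches 6
and 7 simplify decides which input ends up first:

  define <4 x i32> @sudot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
    %u.wide = sext <16 x i8> %u to <16 x i32>
    %s.wide = zext <16 x i8> %s to <16 x i32>
    %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
    %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mult)
    ret <4 x i32> %partial.reduce
  }

  ; There is no non-indexed sudot instruction, so the combine orders the
  ; operands (swapping when needed) to put the unsigned input first and
  ; emits usdot; here that gives:
  ;   usdot v0.4s, v2.16b, v1.16b
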
>From bb7c880147dcd0378b4808ac084512a9ebb525d6 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Tue, 17 Sep 2024 10:30:19 +0100
Subject: [PATCH 8/8] Fix usdot lane matching

---
 llvm/lib/Target/AArch64/AArch64InstrInfo.td |  2 +-
 llvm/test/CodeGen/AArch64/aarch64-matmul.ll | 82 ++++++++++++++-------
 2 files changed, 57 insertions(+), 27 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index f8569e7f65c30a..0c2a797d4cce62 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -1419,7 +1419,7 @@ def  SMMLA : SIMDThreeSameVectorMatMul<0, 0, "smmla", int_aarch64_neon_smmla>;
 def  UMMLA : SIMDThreeSameVectorMatMul<0, 1, "ummla", int_aarch64_neon_ummla>;
 def USMMLA : SIMDThreeSameVectorMatMul<1, 0, "usmmla", int_aarch64_neon_usmmla>;
 defm USDOT : SIMDThreeSameVectorDot<0, 1, "usdot", AArch64usdot>;
-defm USDOTlane : SIMDThreeSameVectorDotIndex<0, 1, 0b10, "usdot", int_aarch64_neon_usdot>;
+defm USDOTlane : SIMDThreeSameVectorDotIndex<0, 1, 0b10, "usdot", AArch64usdot>;
 
 // sudot lane has a pattern where usdot is expected (there is no sudot).
 // The second operand is used in the dup operation to repeat the indexed
diff --git a/llvm/test/CodeGen/AArch64/aarch64-matmul.ll b/llvm/test/CodeGen/AArch64/aarch64-matmul.ll
index 649d0a9bfcab47..52ffbba49ecb49 100644
--- a/llvm/test/CodeGen/AArch64/aarch64-matmul.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-matmul.ll
@@ -1,41 +1,52 @@
 ; RUN: llc -mtriple=aarch64-none-linux-gnu -mattr=+neon,+i8mm < %s -o -| FileCheck %s
 
 define <4 x i32> @smmla.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> %b) {
+; CHECK-LABEL: smmla.v4i32.v16i8:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    smmla v0.4s, v1.16b, v2.16b
+; CHECK-NEXT:    ret
 entry:
-; CHECK-LABEL: smmla.v4i32.v16i8
-; CHECK: smmla   v0.4s, v1.16b, v2.16b
   %vmmla1.i = tail call <4 x i32> @llvm.aarch64.neon.smmla.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> %b)
   ret <4 x i32> %vmmla1.i
 }
 
 define <4 x i32> @ummla.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> %b) {
+; CHECK-LABEL: ummla.v4i32.v16i8:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ummla v0.4s, v1.16b, v2.16b
+; CHECK-NEXT:    ret
 entry:
-; CHECK-LABEL: ummla.v4i32.v16i8
-; CHECK: ummla   v0.4s, v1.16b, v2.16b
   %vmmla1.i = tail call <4 x i32> @llvm.aarch64.neon.ummla.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> %b)
   ret <4 x i32> %vmmla1.i
 }
 
 define <4 x i32> @usmmla.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> %b) {
+; CHECK-LABEL: usmmla.v4i32.v16i8:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    usmmla v0.4s, v1.16b, v2.16b
+; CHECK-NEXT:    ret
 entry:
-; CHECK-LABEL: usmmla.v4i32.v16i8
-; CHECK: usmmla   v0.4s, v1.16b, v2.16b
   %vusmmla1.i = tail call <4 x i32> @llvm.aarch64.neon.usmmla.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> %b) #3
   ret <4 x i32> %vusmmla1.i
 }
 
 define <2 x i32> @usdot.v2i32.v8i8(<2 x i32> %r, <8 x i8> %a, <8 x i8> %b) {
+; CHECK-LABEL: usdot.v2i32.v8i8:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    usdot v0.2s, v1.8b, v2.8b
+; CHECK-NEXT:    ret
 entry:
-; CHECK-LABEL: usdot.v2i32.v8i8
-; CHECK: usdot   v0.2s, v1.8b, v2.8b
   %vusdot1.i = tail call <2 x i32> @llvm.aarch64.neon.usdot.v2i32.v8i8(<2 x i32> %r, <8 x i8> %a, <8 x i8> %b)
   ret <2 x i32> %vusdot1.i
 }
 
 define <2 x i32> @usdot_lane.v2i32.v8i8(<2 x i32> %r, <8 x i8> %a, <8 x i8> %b) {
+; CHECK-LABEL: usdot_lane.v2i32.v8i8:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    // kill: def $d2 killed $d2 def $q2
+; CHECK-NEXT:    usdot v0.2s, v1.8b, v2.4b[0]
+; CHECK-NEXT:    ret
 entry:
-; CHECK-LABEL: usdot_lane.v2i32.v8i8
-; CHECK: usdot   v0.2s, v1.8b, v2.4b[0]
   %0 = bitcast <8 x i8> %b to <2 x i32>
   %shuffle = shufflevector <2 x i32> %0, <2 x i32> undef, <2 x i32> zeroinitializer
   %1 = bitcast <2 x i32> %shuffle to <8 x i8>
@@ -44,9 +55,12 @@ entry:
 }
 
 define <2 x i32> @sudot_lane.v2i32.v8i8(<2 x i32> %r, <8 x i8> %a, <8 x i8> %b) {
+; CHECK-LABEL: sudot_lane.v2i32.v8i8:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    // kill: def $d2 killed $d2 def $q2
+; CHECK-NEXT:    sudot v0.2s, v1.8b, v2.4b[0]
+; CHECK-NEXT:    ret
 entry:
-; CHECK-LABEL: sudot_lane.v2i32.v8i8
-; CHECK: sudot   v0.2s, v1.8b, v2.4b[0]
   %0 = bitcast <8 x i8> %b to <2 x i32>
   %shuffle = shufflevector <2 x i32> %0, <2 x i32> undef, <2 x i32> zeroinitializer
   %1 = bitcast <2 x i32> %shuffle to <8 x i8>
@@ -55,9 +69,11 @@ entry:
 }
 
 define <2 x i32> @usdot_lane.v2i32.v16i8(<2 x i32> %r, <8 x i8> %a, <16 x i8> %b) {
+; CHECK-LABEL: usdot_lane.v2i32.v16i8:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    usdot v0.2s, v1.8b, v2.4b[0]
+; CHECK-NEXT:    ret
 entry:
-; CHECK-LABEL: usdot_lane.v2i32.v16i8
-; CHECK: usdot   v0.2s, v1.8b, v2.4b[0]
   %0 = bitcast <16 x i8> %b to <4 x i32>
   %shuffle = shufflevector <4 x i32> %0, <4 x i32> undef, <2 x i32> zeroinitializer
   %1 = bitcast <2 x i32> %shuffle to <8 x i8>
@@ -66,9 +82,11 @@ entry:
 }
 
 define <2 x i32> @sudot_lane.v2i32.v16i8(<2 x i32> %r, <8 x i8> %a, <16 x i8> %b) {
+; CHECK-LABEL: sudot_lane.v2i32.v16i8:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sudot v0.2s, v1.8b, v2.4b[0]
+; CHECK-NEXT:    ret
 entry:
-; CHECK-LABEL: sudot_lane.v2i32.v16i8
-; CHECK: sudot   v0.2s, v1.8b, v2.4b[0]
   %0 = bitcast <16 x i8> %b to <4 x i32>
   %shuffle = shufflevector <4 x i32> %0, <4 x i32> undef, <2 x i32> zeroinitializer
   %1 = bitcast <2 x i32> %shuffle to <8 x i8>
@@ -77,17 +95,22 @@ entry:
 }
 
 define <4 x i32> @usdot.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> %b) {
+; CHECK-LABEL: usdot.v4i32.v16i8:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    usdot v0.4s, v1.16b, v2.16b
+; CHECK-NEXT:    ret
 entry:
-; CHECK-LABEL: usdot.v4i32.v16i8
-; CHECK: usdot   v0.4s, v1.16b, v2.16b
   %vusdot1.i = tail call <4 x i32> @llvm.aarch64.neon.usdot.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> %b) #3
   ret <4 x i32> %vusdot1.i
 }
 
 define <4 x i32> @usdot_lane.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <8 x i8> %b) {
+; CHECK-LABEL: usdot_lane.v4i32.v16i8:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    // kill: def $d2 killed $d2 def $q2
+; CHECK-NEXT:    usdot v0.4s, v1.16b, v2.4b[0]
+; CHECK-NEXT:    ret
 entry:
-; CHECK-LABEL: usdot_lane.v4i32.v16i8
-; CHECK: usdot   v0.4s, v1.16b, v2.4b[0]
   %0 = bitcast <8 x i8> %b to <2 x i32>
   %shuffle = shufflevector <2 x i32> %0, <2 x i32> undef, <4 x i32> zeroinitializer
   %1 = bitcast <4 x i32> %shuffle to <16 x i8>
@@ -96,9 +119,12 @@ entry:
 }
 
 define <4 x i32> @sudot_lane.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <8 x i8> %b) {
+; CHECK-LABEL: sudot_lane.v4i32.v16i8:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    // kill: def $d2 killed $d2 def $q2
+; CHECK-NEXT:    sudot v0.4s, v1.16b, v2.4b[0]
+; CHECK-NEXT:    ret
 entry:
-; CHECK-LABEL: sudot_lane.v4i32.v16i8
-; CHECK: sudot   v0.4s, v1.16b, v2.4b[0]
   %0 = bitcast <8 x i8> %b to <2 x i32>
   %shuffle = shufflevector <2 x i32> %0, <2 x i32> undef, <4 x i32> zeroinitializer
   %1 = bitcast <4 x i32> %shuffle to <16 x i8>
@@ -107,9 +133,11 @@ entry:
 }
 
 define <4 x i32> @usdot_laneq.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> %b) {
+; CHECK-LABEL: usdot_laneq.v4i32.v16i8:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    usdot v0.4s, v1.16b, v2.4b[0]
+; CHECK-NEXT:    ret
 entry:
-; CHECK-LABEL: usdot_laneq.v4i32.v16i8
-; CHECK: usdot   v0.4s, v1.16b, v2.4b[0]
   %0 = bitcast <16 x i8> %b to <4 x i32>
   %shuffle = shufflevector <4 x i32> %0, <4 x i32> undef, <4 x i32> zeroinitializer
   %1 = bitcast <4 x i32> %shuffle to <16 x i8>
@@ -118,9 +146,11 @@ entry:
 }
 
 define <4 x i32> @sudot_laneq.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> %b) {
+; CHECK-LABEL: sudot_laneq.v4i32.v16i8:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sudot v0.4s, v1.16b, v2.4b[0]
+; CHECK-NEXT:    ret
 entry:
-; CHECK-LABEL: sudot_laneq.v4i32.v16i8
-; CHECK: sudot   v0.4s, v1.16b, v2.4b[0]
   %0 = bitcast <16 x i8> %b to <4 x i32>
   %shuffle = shufflevector <4 x i32> %0, <4 x i32> undef, <4 x i32> zeroinitializer
   %1 = bitcast <4 x i32> %shuffle to <16 x i8>
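
As a complete picture of what the lane fix covers, here is a sketch of one of
the updated tests (reconstructed from the hunks above, so the exact body is
assumed): the splatted lane feeds the usdot intrinsic, which is now selected
through the AArch64usdot node so the indexed form is still emitted.

define <2 x i32> @usdot_lane_sketch(<2 x i32> %r, <8 x i8> %a, <8 x i8> %b) {
  ; Splat lane 0 of %b across the vector, then feed it to usdot; the
  ; indexed pattern folds the splat into usdot v0.2s, v1.8b, v2.4b[0].
  %0 = bitcast <8 x i8> %b to <2 x i32>
  %splat = shufflevector <2 x i32> %0, <2 x i32> undef, <2 x i32> zeroinitializer
  %1 = bitcast <2 x i32> %splat to <8 x i8>
  %res = tail call <2 x i32> @llvm.aarch64.neon.usdot.v2i32.v8i8(<2 x i32> %r, <8 x i8> %a, <8 x i8> %1)
  ret <2 x i32> %res
}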


