[llvm] [AArch64][NEON] Lower fixed-width add partial reductions to dot product (PR #107078)

Sam Tebbs via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 4 05:50:52 PDT 2024


https://github.com/SamTebbs33 updated https://github.com/llvm/llvm-project/pull/107078

From 21f893c14b6c817ff0319511e1296e218d581922 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Tue, 3 Sep 2024 11:06:25 +0100
Subject: [PATCH 1/3] [AArch64][NEON] Lower fixed-width add partial reductions
 to dot product

This PR adds lowering for fixed-width <4 x i32> and <2 x i32> partial add
reductions (llvm.experimental.vector.partial.reduce.add) to a dot product
when Neon and the dot product feature (FEAT_DotProd) are available.

The work is by Max Beck-Jones (@DevM-uk).
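
As a reading aid only (not part of the patch): the udot instruction that the
new tests check for can be modelled in scalar C as below. The function name
and argument layout are illustrative assumptions; the signed sdot variant is
identical with int8_t sources and signed arithmetic.

  #include <stdint.h>

  /* Minimal scalar sketch of "udot v0.4s, v1.16b, v2.16b": each of the four
     32-bit accumulator lanes gains the sum of four adjacent 8x8-bit products,
     which is why the wide type needs four times as many elements as the
     reduced type. */
  void udot_4s_16b_ref(uint32_t acc[4], const uint8_t a[16],
                       const uint8_t b[16]) {
    for (int lane = 0; lane < 4; ++lane)
      for (int i = 0; i < 4; ++i)
        acc[lane] += (uint32_t)a[4 * lane + i] * (uint32_t)b[4 * lane + i];
  }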
---
 .../Target/AArch64/AArch64ISelLowering.cpp    |  16 +-
 .../neon-partial-reduce-dot-product.ll        | 209 ++++++++++++++++++
 2 files changed, 218 insertions(+), 7 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 1735ff5cd69748..f3298a326bf4c1 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1994,7 +1994,8 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
     return true;
 
   EVT VT = EVT::getEVT(I->getType());
-  return VT != MVT::nxv4i32 && VT != MVT::nxv2i64;
+  return VT != MVT::nxv4i32 && VT != MVT::nxv2i64 && VT != MVT::v4i32 &&
+      VT != MVT::v2i32;
 }
 
 bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
@@ -21781,7 +21782,8 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
              Intrinsic::experimental_vector_partial_reduce_add &&
          "Expected a partial reduction node");
 
-  if (!Subtarget->isSVEorStreamingSVEAvailable())
+  if (!Subtarget->isSVEorStreamingSVEAvailable() &&
+          !(Subtarget->isNeonAvailable() && Subtarget->hasDotProd()))
     return SDValue();
 
   SDLoc DL(N);
@@ -21818,11 +21820,11 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
 
   // Dot products operate on chunks of four elements, so there must be four
   // times as many elements in the wide type.
-  if (ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8)
-    return DAG.getNode(Opcode, DL, MVT::nxv4i32, NarrowOp, A, B);
-
-  if (ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16)
-    return DAG.getNode(Opcode, DL, MVT::nxv2i64, NarrowOp, A, B);
+  if ((ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) ||
+      (ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) ||
+      (ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) ||
+      (ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
+    return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
 
   return SDValue();
 }
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
new file mode 100644
index 00000000000000..13b731451b60c1
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -0,0 +1,209 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod < %s | FileCheck %s
+; RUN: llc -mtriple aarch64 -mattr=+neon < %s | FileCheck %s --check-prefix=CHECK-NODOTPROD
+
+define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
+; CHECK-LABEL: udot:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    udot v0.4s, v2.16b, v1.16b
+; CHECK-NEXT:    ret
+;
+; CHECK-NODOTPROD-LABEL: udot:
+; CHECK-NODOTPROD:       // %bb.0:
+; CHECK-NODOTPROD-NEXT:    umull v3.8h, v2.8b, v1.8b
+; CHECK-NODOTPROD-NEXT:    umull2 v1.8h, v2.16b, v1.16b
+; CHECK-NODOTPROD-NEXT:    ushll v2.4s, v1.4h, #0
+; CHECK-NODOTPROD-NEXT:    uaddw v0.4s, v0.4s, v3.4h
+; CHECK-NODOTPROD-NEXT:    uaddw2 v2.4s, v2.4s, v3.8h
+; CHECK-NODOTPROD-NEXT:    uaddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NODOTPROD-NEXT:    add v0.4s, v2.4s, v0.4s
+; CHECK-NODOTPROD-NEXT:    ret
+  %u.wide = zext <16 x i8> %u to <16 x i32>
+  %s.wide = zext <16 x i8> %s to <16 x i32>
+  %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mult)
+  ret <4 x i32> %partial.reduce
+}
+
+define <2 x i32> @udot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
+; CHECK-LABEL: udot_narrow:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    udot v0.2s, v2.8b, v1.8b
+; CHECK-NEXT:    ret
+;
+; CHECK-NODOTPROD-LABEL: udot_narrow:
+; CHECK-NODOTPROD:       // %bb.0:
+; CHECK-NODOTPROD-NEXT:    umull v1.8h, v2.8b, v1.8b
+; CHECK-NODOTPROD-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NODOTPROD-NEXT:    ushll v2.4s, v1.4h, #0
+; CHECK-NODOTPROD-NEXT:    ushll2 v3.4s, v1.8h, #0
+; CHECK-NODOTPROD-NEXT:    ext v4.16b, v1.16b, v1.16b, #8
+; CHECK-NODOTPROD-NEXT:    uaddw v0.4s, v0.4s, v1.4h
+; CHECK-NODOTPROD-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NODOTPROD-NEXT:    ext v2.16b, v2.16b, v2.16b, #8
+; CHECK-NODOTPROD-NEXT:    add v0.2s, v3.2s, v0.2s
+; CHECK-NODOTPROD-NEXT:    uaddw v1.4s, v2.4s, v4.4h
+; CHECK-NODOTPROD-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NODOTPROD-NEXT:    ret
+  %u.wide = zext <8 x i8> %u to <8 x i32>
+  %s.wide = zext <8 x i8> %s to <8 x i32>
+  %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v2i32.v8i32(<2 x i32> %acc, <8 x i32> %mult)
+  ret <2 x i32> %partial.reduce
+}
+
+define <4 x i32> @sdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
+; CHECK-LABEL: sdot:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sdot v0.4s, v2.16b, v1.16b
+; CHECK-NEXT:    ret
+;
+; CHECK-NODOTPROD-LABEL: sdot:
+; CHECK-NODOTPROD:       // %bb.0:
+; CHECK-NODOTPROD-NEXT:    smull v3.8h, v2.8b, v1.8b
+; CHECK-NODOTPROD-NEXT:    smull2 v1.8h, v2.16b, v1.16b
+; CHECK-NODOTPROD-NEXT:    sshll v2.4s, v1.4h, #0
+; CHECK-NODOTPROD-NEXT:    saddw v0.4s, v0.4s, v3.4h
+; CHECK-NODOTPROD-NEXT:    saddw2 v2.4s, v2.4s, v3.8h
+; CHECK-NODOTPROD-NEXT:    saddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NODOTPROD-NEXT:    add v0.4s, v2.4s, v0.4s
+; CHECK-NODOTPROD-NEXT:    ret
+  %u.wide = sext <16 x i8> %u to <16 x i32>
+  %s.wide = sext <16 x i8> %s to <16 x i32>
+  %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mult)
+  ret <4 x i32> %partial.reduce
+}
+
+define <2 x i32> @sdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
+; CHECK-LABEL: sdot_narrow:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sdot v0.2s, v2.8b, v1.8b
+; CHECK-NEXT:    ret
+;
+; CHECK-NODOTPROD-LABEL: sdot_narrow:
+; CHECK-NODOTPROD:       // %bb.0:
+; CHECK-NODOTPROD-NEXT:    smull v1.8h, v2.8b, v1.8b
+; CHECK-NODOTPROD-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NODOTPROD-NEXT:    sshll v2.4s, v1.4h, #0
+; CHECK-NODOTPROD-NEXT:    sshll2 v3.4s, v1.8h, #0
+; CHECK-NODOTPROD-NEXT:    ext v4.16b, v1.16b, v1.16b, #8
+; CHECK-NODOTPROD-NEXT:    saddw v0.4s, v0.4s, v1.4h
+; CHECK-NODOTPROD-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NODOTPROD-NEXT:    ext v2.16b, v2.16b, v2.16b, #8
+; CHECK-NODOTPROD-NEXT:    add v0.2s, v3.2s, v0.2s
+; CHECK-NODOTPROD-NEXT:    saddw v1.4s, v2.4s, v4.4h
+; CHECK-NODOTPROD-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NODOTPROD-NEXT:    ret
+  %u.wide = sext <8 x i8> %u to <8 x i32>
+  %s.wide = sext <8 x i8> %s to <8 x i32>
+  %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v2i32.v8i32(<2 x i32> %acc, <8 x i32> %mult)
+  ret <2 x i32> %partial.reduce
+}
+
+define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
+; CHECK-LABEL: not_udot:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    umull v1.8h, v2.8b, v1.8b
+; CHECK-NEXT:    uaddw v0.4s, v0.4s, v1.4h
+; CHECK-NEXT:    uaddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NEXT:    ret
+;
+; CHECK-NODOTPROD-LABEL: not_udot:
+; CHECK-NODOTPROD:       // %bb.0:
+; CHECK-NODOTPROD-NEXT:    umull v1.8h, v2.8b, v1.8b
+; CHECK-NODOTPROD-NEXT:    uaddw v0.4s, v0.4s, v1.4h
+; CHECK-NODOTPROD-NEXT:    uaddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NODOTPROD-NEXT:    ret
+  %u.wide = zext <8 x i8> %u to <8 x i32>
+  %s.wide = zext <8 x i8> %s to <8 x i32>
+  %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v8i32(<4 x i32> %acc, <8 x i32> %mult)
+  ret <4 x i32> %partial.reduce
+}
+
+define <2 x i32> @not_udot_narrow(<2 x i32> %acc, <4 x i8> %u, <4 x i8> %s) {
+; CHECK-LABEL: not_udot_narrow:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    bic v1.4h, #255, lsl #8
+; CHECK-NEXT:    bic v2.4h, #255, lsl #8
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NEXT:    umull v3.4s, v2.4h, v1.4h
+; CHECK-NEXT:    umlal v0.4s, v2.4h, v1.4h
+; CHECK-NEXT:    ext v1.16b, v3.16b, v3.16b, #8
+; CHECK-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NEXT:    ret
+;
+; CHECK-NODOTPROD-LABEL: not_udot_narrow:
+; CHECK-NODOTPROD:       // %bb.0:
+; CHECK-NODOTPROD-NEXT:    bic v1.4h, #255, lsl #8
+; CHECK-NODOTPROD-NEXT:    bic v2.4h, #255, lsl #8
+; CHECK-NODOTPROD-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NODOTPROD-NEXT:    umull v3.4s, v2.4h, v1.4h
+; CHECK-NODOTPROD-NEXT:    umlal v0.4s, v2.4h, v1.4h
+; CHECK-NODOTPROD-NEXT:    ext v1.16b, v3.16b, v3.16b, #8
+; CHECK-NODOTPROD-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NODOTPROD-NEXT:    ret
+  %u.wide = zext <4 x i8> %u to <4 x i32>
+  %s.wide = zext <4 x i8> %s to <4 x i32>
+  %mult = mul nuw nsw <4 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v2i32.v4i32(<2 x i32> %acc, <4 x i32> %mult)
+  ret <2 x i32> %partial.reduce
+}
+
+define <4 x i32> @not_sdot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
+; CHECK-LABEL: not_sdot:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    smull v1.8h, v2.8b, v1.8b
+; CHECK-NEXT:    saddw v0.4s, v0.4s, v1.4h
+; CHECK-NEXT:    saddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NEXT:    ret
+;
+; CHECK-NODOTPROD-LABEL: not_sdot:
+; CHECK-NODOTPROD:       // %bb.0:
+; CHECK-NODOTPROD-NEXT:    smull v1.8h, v2.8b, v1.8b
+; CHECK-NODOTPROD-NEXT:    saddw v0.4s, v0.4s, v1.4h
+; CHECK-NODOTPROD-NEXT:    saddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NODOTPROD-NEXT:    ret
+  %u.wide = sext <8 x i8> %u to <8 x i32>
+  %s.wide = sext <8 x i8> %s to <8 x i32>
+  %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v8i32(<4 x i32> %acc, <8 x i32> %mult)
+  ret <4 x i32> %partial.reduce
+}
+
+define <2 x i32> @not_sdot_narrow(<2 x i32> %acc, <4 x i8> %u, <4 x i8> %s) {
+; CHECK-LABEL: not_sdot_narrow:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ushll v2.4s, v2.4h, #0
+; CHECK-NEXT:    ushll v1.4s, v1.4h, #0
+; CHECK-NEXT:    shl v1.4s, v1.4s, #24
+; CHECK-NEXT:    shl v2.4s, v2.4s, #24
+; CHECK-NEXT:    sshr v1.4s, v1.4s, #24
+; CHECK-NEXT:    sshr v2.4s, v2.4s, #24
+; CHECK-NEXT:    mul v1.4s, v2.4s, v1.4s
+; CHECK-NEXT:    ext v2.16b, v1.16b, v1.16b, #8
+; CHECK-NEXT:    add v0.2s, v0.2s, v1.2s
+; CHECK-NEXT:    add v0.2s, v2.2s, v0.2s
+; CHECK-NEXT:    ret
+;
+; CHECK-NODOTPROD-LABEL: not_sdot_narrow:
+; CHECK-NODOTPROD:       // %bb.0:
+; CHECK-NODOTPROD-NEXT:    ushll v2.4s, v2.4h, #0
+; CHECK-NODOTPROD-NEXT:    ushll v1.4s, v1.4h, #0
+; CHECK-NODOTPROD-NEXT:    shl v1.4s, v1.4s, #24
+; CHECK-NODOTPROD-NEXT:    shl v2.4s, v2.4s, #24
+; CHECK-NODOTPROD-NEXT:    sshr v1.4s, v1.4s, #24
+; CHECK-NODOTPROD-NEXT:    sshr v2.4s, v2.4s, #24
+; CHECK-NODOTPROD-NEXT:    mul v1.4s, v2.4s, v1.4s
+; CHECK-NODOTPROD-NEXT:    ext v2.16b, v1.16b, v1.16b, #8
+; CHECK-NODOTPROD-NEXT:    add v0.2s, v0.2s, v1.2s
+; CHECK-NODOTPROD-NEXT:    add v0.2s, v2.2s, v0.2s
+; CHECK-NODOTPROD-NEXT:    ret
+  %u.wide = sext <4 x i8> %u to <4 x i32>
+  %s.wide = sext <4 x i8> %s to <4 x i32>
+  %mult = mul nuw nsw <4 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v2i32.v4i32(<2 x i32> %acc, <4 x i32> %mult)
+  ret <2 x i32> %partial.reduce
+}
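
(Editorial aside, not part of the patch: the 64-bit-vector form exercised by
udot_narrow and sdot_narrow above follows the same model with two accumulator
lanes; a matching hedged sketch, with an illustrative name:)

  #include <stdint.h>

  /* Scalar sketch of "udot v0.2s, v1.8b, v2.8b": two 32-bit lanes, each
     accumulating four adjacent 8x8-bit products. */
  void udot_2s_8b_ref(uint32_t acc[2], const uint8_t a[8],
                      const uint8_t b[8]) {
    for (int lane = 0; lane < 2; ++lane)
      for (int i = 0; i < 4; ++i)
        acc[lane] += (uint32_t)a[4 * lane + i] * (uint32_t)b[4 * lane + i];
  }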

From 35e72f2654e149b076c258492aa3a58662709378 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Tue, 3 Sep 2024 11:40:57 +0100
Subject: [PATCH 2/3] Fix formatting (git-clang-format crashed before)

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index f3298a326bf4c1..82281c4dd25f88 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1995,7 +1995,7 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
 
   EVT VT = EVT::getEVT(I->getType());
   return VT != MVT::nxv4i32 && VT != MVT::nxv2i64 && VT != MVT::v4i32 &&
-      VT != MVT::v2i32;
+         VT != MVT::v2i32;
 }
 
 bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
@@ -21783,7 +21783,7 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
          "Expected a partial reduction node");
 
   if (!Subtarget->isSVEorStreamingSVEAvailable() &&
-          !(Subtarget->isNeonAvailable() && Subtarget->hasDotProd()))
+      !(Subtarget->isNeonAvailable() && Subtarget->hasDotProd()))
     return SDValue();
 
   SDLoc DL(N);

From ffa0d4b5961bedbbeb952a8f6211945d57ea3a48 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Wed, 4 Sep 2024 13:12:09 +0100
Subject: [PATCH 3/3] Check for dotprod when SVE is available as well

---
 .../Target/AArch64/AArch64ISelLowering.cpp    |   6 +-
 .../AArch64/partial-reduce-dot-product.ll     | 157 +++++++++++++++++-
 2 files changed, 161 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 82281c4dd25f88..87a7b70b24dbd3 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21783,7 +21783,11 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
          "Expected a partial reduction node");
 
   if (!Subtarget->isSVEorStreamingSVEAvailable() &&
-      !(Subtarget->isNeonAvailable() && Subtarget->hasDotProd()))
+      !Subtarget->isNeonAvailable())
+    return SDValue();
+
+  // Fixed-width types require the dotprod feature, for both Neon and SVE.
+  if (!N->getValueType(0).isScalableVT() && !Subtarget->hasDotProd())
     return SDValue();
 
   SDLoc DL(N);
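
(Editorial aside: taken together, the two early-outs above reduce to the
following decision sketch. The helper name and flat boolean parameters are
illustrative only, not the real AArch64Subtarget interface.)

  #include <stdbool.h>

  /* A partial reduction is only lowered to a dot product if some vector unit
     is usable at all and, for fixed-width result types, the target also has
     the dotprod feature; the type checks still happen later. */
  bool canLowerToDot(bool hasSVEorStreamingSVE, bool hasNeon, bool hasDotProd,
                     bool resultIsScalable) {
    if (!hasSVEorStreamingSVE && !hasNeon)
      return false; /* no SVE and no Neon vector unit */
    if (!resultIsScalable && !hasDotProd)
      return false; /* fixed-width udot/sdot needs dotprod */
    return true;
  }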
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
index b1354ab210f727..daf5ec5f681367 100644
--- a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
@@ -1,11 +1,17 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s
+; RUN: llc -mtriple=aarch64 -mattr=+sve2,+dotprod %s -o - | FileCheck %s
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefix=CHECK-NODOTPROD
 
 define <vscale x 4 x i32> @dotp(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
 ; CHECK-LABEL: dotp:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    udot z0.s, z1.b, z2.b
 ; CHECK-NEXT:    ret
+;
+; CHECK-NODOTPROD-LABEL: dotp:
+; CHECK-NODOTPROD:       // %bb.0: // %entry
+; CHECK-NODOTPROD-NEXT:    udot z0.s, z1.b, z2.b
+; CHECK-NODOTPROD-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
   %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
@@ -19,6 +25,11 @@ define <vscale x 2 x i64> @dotp_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16>
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    udot z0.d, z1.h, z2.h
 ; CHECK-NEXT:    ret
+;
+; CHECK-NODOTPROD-LABEL: dotp_wide:
+; CHECK-NODOTPROD:       // %bb.0: // %entry
+; CHECK-NODOTPROD-NEXT:    udot z0.d, z1.h, z2.h
+; CHECK-NODOTPROD-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
   %b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i64>
@@ -32,6 +43,11 @@ define <vscale x 4 x i32> @dotp_sext(<vscale x 4 x i32> %accc, <vscale x 16 x i8
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    sdot z0.s, z1.b, z2.b
 ; CHECK-NEXT:    ret
+;
+; CHECK-NODOTPROD-LABEL: dotp_sext:
+; CHECK-NODOTPROD:       // %bb.0: // %entry
+; CHECK-NODOTPROD-NEXT:    sdot z0.s, z1.b, z2.b
+; CHECK-NODOTPROD-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
   %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
@@ -45,6 +61,11 @@ define <vscale x 2 x i64> @dotp_wide_sext(<vscale x 2 x i64> %acc, <vscale x 8 x
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    sdot z0.d, z1.h, z2.h
 ; CHECK-NEXT:    ret
+;
+; CHECK-NODOTPROD-LABEL: dotp_wide_sext:
+; CHECK-NODOTPROD:       // %bb.0: // %entry
+; CHECK-NODOTPROD-NEXT:    sdot z0.d, z1.h, z2.h
+; CHECK-NODOTPROD-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
   %b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i64>
@@ -53,6 +74,114 @@ entry:
   ret <vscale x 2 x i64> %partial.reduce
 }
 
+define <4 x i32> @dotp_fixed(<4 x i32> %acc, <16 x i8> %a, <16 x i8> %b) {
+; CHECK-LABEL: dotp_fixed:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    udot v0.4s, v1.16b, v2.16b
+; CHECK-NEXT:    ret
+;
+; CHECK-NODOTPROD-LABEL: dotp_fixed:
+; CHECK-NODOTPROD:       // %bb.0: // %entry
+; CHECK-NODOTPROD-NEXT:    umull v3.8h, v1.8b, v2.8b
+; CHECK-NODOTPROD-NEXT:    umull2 v1.8h, v1.16b, v2.16b
+; CHECK-NODOTPROD-NEXT:    ushll v2.4s, v1.4h, #0
+; CHECK-NODOTPROD-NEXT:    uaddw v0.4s, v0.4s, v3.4h
+; CHECK-NODOTPROD-NEXT:    uaddw2 v2.4s, v2.4s, v3.8h
+; CHECK-NODOTPROD-NEXT:    uaddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NODOTPROD-NEXT:    add v0.4s, v2.4s, v0.4s
+; CHECK-NODOTPROD-NEXT:    ret
+entry:
+  %a.wide = zext <16 x i8> %a to <16 x i32>
+  %b.wide = zext <16 x i8> %b to <16 x i32>
+  %mult = mul nuw nsw <16 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mult)
+  ret <4 x i32> %partial.reduce
+}
+
+define <2 x i64> @dotp_fixed_wide(<2 x i64> %acc, <8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: dotp_fixed_wide:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    umull v3.4s, v1.4h, v2.4h
+; CHECK-NEXT:    umull2 v1.4s, v1.8h, v2.8h
+; CHECK-NEXT:    ushll v2.2d, v1.2s, #0
+; CHECK-NEXT:    uaddw v0.2d, v0.2d, v3.2s
+; CHECK-NEXT:    uaddw2 v2.2d, v2.2d, v3.4s
+; CHECK-NEXT:    uaddw2 v0.2d, v0.2d, v1.4s
+; CHECK-NEXT:    add v0.2d, v2.2d, v0.2d
+; CHECK-NEXT:    ret
+;
+; CHECK-NODOTPROD-LABEL: dotp_fixed_wide:
+; CHECK-NODOTPROD:       // %bb.0: // %entry
+; CHECK-NODOTPROD-NEXT:    umull v3.4s, v1.4h, v2.4h
+; CHECK-NODOTPROD-NEXT:    umull2 v1.4s, v1.8h, v2.8h
+; CHECK-NODOTPROD-NEXT:    ushll v2.2d, v1.2s, #0
+; CHECK-NODOTPROD-NEXT:    uaddw v0.2d, v0.2d, v3.2s
+; CHECK-NODOTPROD-NEXT:    uaddw2 v2.2d, v2.2d, v3.4s
+; CHECK-NODOTPROD-NEXT:    uaddw2 v0.2d, v0.2d, v1.4s
+; CHECK-NODOTPROD-NEXT:    add v0.2d, v2.2d, v0.2d
+; CHECK-NODOTPROD-NEXT:    ret
+entry:
+  %a.wide = zext <8 x i16> %a to <8 x i64>
+  %b.wide = zext <8 x i16> %b to <8 x i64>
+  %mult = mul nuw nsw <8 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add.v2i64.v8i64(<2 x i64> %acc, <8 x i64> %mult)
+  ret <2 x i64> %partial.reduce
+}
+
+define <4 x i32> @dotp_fixed_sext(<4 x i32> %accc, <16 x i8> %a, <16 x i8> %b) {
+; CHECK-LABEL: dotp_fixed_sext:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sdot v0.4s, v1.16b, v2.16b
+; CHECK-NEXT:    ret
+;
+; CHECK-NODOTPROD-LABEL: dotp_fixed_sext:
+; CHECK-NODOTPROD:       // %bb.0: // %entry
+; CHECK-NODOTPROD-NEXT:    smull v3.8h, v1.8b, v2.8b
+; CHECK-NODOTPROD-NEXT:    smull2 v1.8h, v1.16b, v2.16b
+; CHECK-NODOTPROD-NEXT:    sshll v2.4s, v1.4h, #0
+; CHECK-NODOTPROD-NEXT:    saddw v0.4s, v0.4s, v3.4h
+; CHECK-NODOTPROD-NEXT:    saddw2 v2.4s, v2.4s, v3.8h
+; CHECK-NODOTPROD-NEXT:    saddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NODOTPROD-NEXT:    add v0.4s, v2.4s, v0.4s
+; CHECK-NODOTPROD-NEXT:    ret
+entry:
+  %a.wide = sext <16 x i8> %a to <16 x i32>
+  %b.wide = sext <16 x i8> %b to <16 x i32>
+  %mult = mul nuw nsw <16 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %accc, <16 x i32> %mult)
+  ret <4 x i32> %partial.reduce
+}
+
+define <2 x i64> @dotp_fixed_wide_sext(<2 x i64> %acc, <8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: dotp_fixed_wide_sext:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    smull v3.4s, v1.4h, v2.4h
+; CHECK-NEXT:    smull2 v1.4s, v1.8h, v2.8h
+; CHECK-NEXT:    sshll v2.2d, v1.2s, #0
+; CHECK-NEXT:    saddw v0.2d, v0.2d, v3.2s
+; CHECK-NEXT:    saddw2 v2.2d, v2.2d, v3.4s
+; CHECK-NEXT:    saddw2 v0.2d, v0.2d, v1.4s
+; CHECK-NEXT:    add v0.2d, v2.2d, v0.2d
+; CHECK-NEXT:    ret
+;
+; CHECK-NODOTPROD-LABEL: dotp_fixed_wide_sext:
+; CHECK-NODOTPROD:       // %bb.0: // %entry
+; CHECK-NODOTPROD-NEXT:    smull v3.4s, v1.4h, v2.4h
+; CHECK-NODOTPROD-NEXT:    smull2 v1.4s, v1.8h, v2.8h
+; CHECK-NODOTPROD-NEXT:    sshll v2.2d, v1.2s, #0
+; CHECK-NODOTPROD-NEXT:    saddw v0.2d, v0.2d, v3.2s
+; CHECK-NODOTPROD-NEXT:    saddw2 v2.2d, v2.2d, v3.4s
+; CHECK-NODOTPROD-NEXT:    saddw2 v0.2d, v0.2d, v1.4s
+; CHECK-NODOTPROD-NEXT:    add v0.2d, v2.2d, v0.2d
+; CHECK-NODOTPROD-NEXT:    ret
+entry:
+  %a.wide = sext <8 x i16> %a to <8 x i64>
+  %b.wide = sext <8 x i16> %b to <8 x i64>
+  %mult = mul nuw nsw <8 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add.v2i64.v8i64(<2 x i64> %acc, <8 x i64> %mult)
+  ret <2 x i64> %partial.reduce
+}
+
 define <vscale x 4 x i32> @not_dotp(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
 ; CHECK-LABEL: not_dotp:
 ; CHECK:       // %bb.0: // %entry
@@ -66,6 +195,19 @@ define <vscale x 4 x i32> @not_dotp(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %
 ; CHECK-NEXT:    mla z0.s, p0/m, z3.s, z4.s
 ; CHECK-NEXT:    mla z0.s, p0/m, z1.s, z2.s
 ; CHECK-NEXT:    ret
+;
+; CHECK-NODOTPROD-LABEL: not_dotp:
+; CHECK-NODOTPROD:       // %bb.0: // %entry
+; CHECK-NODOTPROD-NEXT:    and z1.h, z1.h, #0xff
+; CHECK-NODOTPROD-NEXT:    and z2.h, z2.h, #0xff
+; CHECK-NODOTPROD-NEXT:    ptrue p0.s
+; CHECK-NODOTPROD-NEXT:    uunpklo z3.s, z1.h
+; CHECK-NODOTPROD-NEXT:    uunpklo z4.s, z2.h
+; CHECK-NODOTPROD-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NODOTPROD-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NODOTPROD-NEXT:    mla z0.s, p0/m, z3.s, z4.s
+; CHECK-NODOTPROD-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NODOTPROD-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 8 x i8> %a to <vscale x 8 x i32>
   %b.wide = zext <vscale x 8 x i8> %b to <vscale x 8 x i32>
@@ -87,6 +229,19 @@ define <vscale x 2 x i64> @not_dotp_wide(<vscale x 2 x i64> %acc, <vscale x 4 x
 ; CHECK-NEXT:    mla z0.d, p0/m, z3.d, z4.d
 ; CHECK-NEXT:    mla z0.d, p0/m, z1.d, z2.d
 ; CHECK-NEXT:    ret
+;
+; CHECK-NODOTPROD-LABEL: not_dotp_wide:
+; CHECK-NODOTPROD:       // %bb.0: // %entry
+; CHECK-NODOTPROD-NEXT:    and z1.s, z1.s, #0xffff
+; CHECK-NODOTPROD-NEXT:    and z2.s, z2.s, #0xffff
+; CHECK-NODOTPROD-NEXT:    ptrue p0.d
+; CHECK-NODOTPROD-NEXT:    uunpklo z3.d, z1.s
+; CHECK-NODOTPROD-NEXT:    uunpklo z4.d, z2.s
+; CHECK-NODOTPROD-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NODOTPROD-NEXT:    uunpkhi z2.d, z2.s
+; CHECK-NODOTPROD-NEXT:    mla z0.d, p0/m, z3.d, z4.d
+; CHECK-NODOTPROD-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NODOTPROD-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 4 x i16> %a to <vscale x 4 x i64>
   %b.wide = zext <vscale x 4 x i16> %b to <vscale x 4 x i64>


