[llvm] [AArch64][NEON][SVE] Lower mixed sign/zero extended partial reductions to usdot (PR #107566)
Sam Tebbs via llvm-commits
llvm-commits at lists.llvm.org
Fri Sep 6 04:25:21 PDT 2024
https://github.com/SamTebbs33 created https://github.com/llvm/llvm-project/pull/107566
This PR adds lowering of partial reductions whose multiply inputs are a mix of sign- and zero-extended values to the usdot intrinsic.
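For example (taken from the new NEON tests in this patch), a partial reduction whose multiply operands mix a zero-extend and a sign-extend now selects a single usdot when +dotprod and +i8mm are available:

define <4 x i32> @usdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
  %u.wide = zext <16 x i8> %u to <16 x i32>
  %s.wide = sext <16 x i8> %s to <16 x i32>
  %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mult)
  ret <4 x i32> %partial.reduce
}

This lowers to "usdot v0.4s, v1.16b, v2.16b" instead of the ushll/sshll/smlal sequence emitted without i8mm. When the signed input appears first (the sudot cases in the tests), the operands are swapped so that the unsigned input becomes the first usdot operand.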
>From ac88857b22b24a59b19d88dc38ddf5c11be5454d Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Tue, 3 Sep 2024 11:06:25 +0100
Subject: [PATCH] [AArch64][NEON][SVE] Lower mixed sign/zero extended partial
reductions to usdot
This PR adds lowering of partial reductions whose multiply inputs are a mix
of sign- and zero-extended values to the usdot intrinsic.
---
.../Target/AArch64/AArch64ISelLowering.cpp | 62 ++++--
.../neon-partial-reduce-dot-product.ll | 195 +++++++++++++++++-
.../AArch64/sve-partial-reduce-dot-product.ll | 161 +++++++++++++--
3 files changed, 382 insertions(+), 36 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 9d80087336d230..a3b372e677f98a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21824,37 +21824,59 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
auto ExtA = MulOp->getOperand(0);
auto ExtB = MulOp->getOperand(1);
- bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
- bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
- if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt))
- return SDValue();
-
auto A = ExtA->getOperand(0);
auto B = ExtB->getOperand(0);
if (A.getValueType() != B.getValueType())
return SDValue();
- unsigned Opcode = 0;
-
- if (IsSExt)
- Opcode = AArch64ISD::SDOT;
- else if (IsZExt)
- Opcode = AArch64ISD::UDOT;
-
- assert(Opcode != 0 && "Unexpected dot product case encountered.");
-
EVT ReducedType = N->getValueType(0);
EVT MulSrcType = A.getValueType();
// Dot products operate on chunks of four elements so there must be four times
// as many elements in the wide type
- if ((ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) ||
- (ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) ||
- (ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) ||
- (ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
- return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
+ if (!(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
+ !(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) &&
+ !(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) &&
+ !(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
+ return SDValue();
- return SDValue();
+ bool AIsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
+ bool AIsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
+ bool BIsSExt = ExtB->getOpcode() == ISD::SIGN_EXTEND;
+ bool BIsZExt = ExtB->getOpcode() == ISD::ZERO_EXTEND;
+ if (!(AIsSExt || AIsZExt) || !(BIsSExt || BIsZExt))
+ return SDValue();
+
+ // If the extensions are mixed, we should lower it to a usdot instead
+ if (AIsZExt != BIsZExt) {
+ if (!Subtarget->hasMatMulInt8())
+ return SDValue();
+ bool Scalable = N->getValueType(0).isScalableVT();
+
+ // There's no nxv2i64 version of usdot
+ if (Scalable && ReducedType != MVT::nxv4i32)
+ return SDValue();
+
+ unsigned IntrinsicID =
+ Scalable ? Intrinsic::aarch64_sve_usdot : Intrinsic::aarch64_neon_usdot;
+ // USDOT expects the first operand to be unsigned, so swap the operands if
+ // the first is signed and the second is unsigned
+ if (AIsSExt && BIsZExt)
+ std::swap(A, B);
+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, ReducedType,
+ DAG.getConstant(IntrinsicID, DL, MVT::i64), NarrowOp, A,
+ B);
+ }
+
+ unsigned Opcode = 0;
+ if (AIsSExt)
+ Opcode = AArch64ISD::SDOT;
+ else if (AIsZExt)
+ Opcode = AArch64ISD::UDOT;
+
+ assert(Opcode != 0 && "Unexpected dot product case encountered.");
+
+ return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
}
static SDValue performIntrinsicCombine(SDNode *N,
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index 8035504d5558b1..7b6c01f4691175 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -1,6 +1,7 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT
-; RUN: llc -mtriple aarch64 -mattr=+neon < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NODOT
+; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT
+; RUN: llc -mtriple aarch64 -mattr=+neon,+i8mm < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NODOT
+; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOIMM8
define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
; CHECK-DOT-LABEL: udot:
@@ -18,6 +19,11 @@ define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
; CHECK-NODOT-NEXT: uaddw2 v0.4s, v0.4s, v1.8h
; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
; CHECK-NODOT-NEXT: ret
+;
+; CHECK-NOIMM8-LABEL: udot:
+; CHECK-NOIMM8: // %bb.0:
+; CHECK-NOIMM8-NEXT: udot v0.4s, v2.16b, v1.16b
+; CHECK-NOIMM8-NEXT: ret
%u.wide = zext <16 x i8> %u to <16 x i32>
%s.wide = zext <16 x i8> %s to <16 x i32>
%mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
@@ -45,6 +51,11 @@ define <2 x i32> @udot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
; CHECK-NODOT-NEXT: uaddw v1.4s, v2.4s, v4.4h
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
; CHECK-NODOT-NEXT: ret
+;
+; CHECK-NOIMM8-LABEL: udot_narrow:
+; CHECK-NOIMM8: // %bb.0:
+; CHECK-NOIMM8-NEXT: udot v0.2s, v2.8b, v1.8b
+; CHECK-NOIMM8-NEXT: ret
%u.wide = zext <8 x i8> %u to <8 x i32>
%s.wide = zext <8 x i8> %s to <8 x i32>
%mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
@@ -68,6 +79,11 @@ define <4 x i32> @sdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
; CHECK-NODOT-NEXT: saddw2 v0.4s, v0.4s, v1.8h
; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
; CHECK-NODOT-NEXT: ret
+;
+; CHECK-NOIMM8-LABEL: sdot:
+; CHECK-NOIMM8: // %bb.0:
+; CHECK-NOIMM8-NEXT: sdot v0.4s, v2.16b, v1.16b
+; CHECK-NOIMM8-NEXT: ret
%u.wide = sext <16 x i8> %u to <16 x i32>
%s.wide = sext <16 x i8> %s to <16 x i32>
%mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
@@ -95,6 +111,11 @@ define <2 x i32> @sdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
; CHECK-NODOT-NEXT: saddw v1.4s, v2.4s, v4.4h
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
; CHECK-NODOT-NEXT: ret
+;
+; CHECK-NOIMM8-LABEL: sdot_narrow:
+; CHECK-NOIMM8: // %bb.0:
+; CHECK-NOIMM8-NEXT: sdot v0.2s, v2.8b, v1.8b
+; CHECK-NOIMM8-NEXT: ret
%u.wide = sext <8 x i8> %u to <8 x i32>
%s.wide = sext <8 x i8> %s to <8 x i32>
%mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
@@ -102,7 +123,175 @@ define <2 x i32> @sdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
ret <2 x i32> %partial.reduce
}
-define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
+define <4 x i32> @usdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
+; CHECK-DOT-LABEL: usdot:
+; CHECK-DOT: // %bb.0:
+; CHECK-DOT-NEXT: usdot v0.4s, v1.16b, v2.16b
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-NODOT-LABEL: usdot:
+; CHECK-NODOT: // %bb.0:
+; CHECK-NODOT-NEXT: ushll v3.8h, v1.8b, #0
+; CHECK-NODOT-NEXT: ushll2 v1.8h, v1.16b, #0
+; CHECK-NODOT-NEXT: sshll v4.8h, v2.8b, #0
+; CHECK-NODOT-NEXT: sshll2 v2.8h, v2.16b, #0
+; CHECK-NODOT-NEXT: smlal v0.4s, v4.4h, v3.4h
+; CHECK-NODOT-NEXT: smull v5.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT: smlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NODOT-NEXT: smlal2 v5.4s, v4.8h, v3.8h
+; CHECK-NODOT-NEXT: add v0.4s, v5.4s, v0.4s
+; CHECK-NODOT-NEXT: ret
+;
+; CHECK-NOIMM8-LABEL: usdot:
+; CHECK-NOIMM8: // %bb.0:
+; CHECK-NOIMM8-NEXT: ushll v3.8h, v1.8b, #0
+; CHECK-NOIMM8-NEXT: ushll2 v1.8h, v1.16b, #0
+; CHECK-NOIMM8-NEXT: sshll v4.8h, v2.8b, #0
+; CHECK-NOIMM8-NEXT: sshll2 v2.8h, v2.16b, #0
+; CHECK-NOIMM8-NEXT: smlal v0.4s, v4.4h, v3.4h
+; CHECK-NOIMM8-NEXT: smull v5.4s, v2.4h, v1.4h
+; CHECK-NOIMM8-NEXT: smlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NOIMM8-NEXT: smlal2 v5.4s, v4.8h, v3.8h
+; CHECK-NOIMM8-NEXT: add v0.4s, v5.4s, v0.4s
+; CHECK-NOIMM8-NEXT: ret
+ %u.wide = zext <16 x i8> %u to <16 x i32>
+ %s.wide = sext <16 x i8> %s to <16 x i32>
+ %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+ %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mult)
+ ret <4 x i32> %partial.reduce
+}
+
+define <2 x i32> @usdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
+; CHECK-DOT-LABEL: usdot_narrow:
+; CHECK-DOT: // %bb.0:
+; CHECK-DOT-NEXT: usdot v0.2s, v1.8b, v2.8b
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-NODOT-LABEL: usdot_narrow:
+; CHECK-NODOT: // %bb.0:
+; CHECK-NODOT-NEXT: ushll v1.8h, v1.8b, #0
+; CHECK-NODOT-NEXT: sshll v2.8h, v2.8b, #0
+; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-NODOT-NEXT: smull v3.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT: smull2 v4.4s, v2.8h, v1.8h
+; CHECK-NODOT-NEXT: ext v5.16b, v1.16b, v1.16b, #8
+; CHECK-NODOT-NEXT: ext v6.16b, v2.16b, v2.16b, #8
+; CHECK-NODOT-NEXT: smlal v0.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NODOT-NEXT: ext v1.16b, v4.16b, v4.16b, #8
+; CHECK-NODOT-NEXT: smlal v3.4s, v6.4h, v5.4h
+; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
+; CHECK-NODOT-NEXT: ret
+;
+; CHECK-NOIMM8-LABEL: usdot_narrow:
+; CHECK-NOIMM8: // %bb.0:
+; CHECK-NOIMM8-NEXT: ushll v1.8h, v1.8b, #0
+; CHECK-NOIMM8-NEXT: sshll v2.8h, v2.8b, #0
+; CHECK-NOIMM8-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-NOIMM8-NEXT: smull v3.4s, v2.4h, v1.4h
+; CHECK-NOIMM8-NEXT: smull2 v4.4s, v2.8h, v1.8h
+; CHECK-NOIMM8-NEXT: ext v5.16b, v1.16b, v1.16b, #8
+; CHECK-NOIMM8-NEXT: ext v6.16b, v2.16b, v2.16b, #8
+; CHECK-NOIMM8-NEXT: smlal v0.4s, v2.4h, v1.4h
+; CHECK-NOIMM8-NEXT: ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NOIMM8-NEXT: ext v1.16b, v4.16b, v4.16b, #8
+; CHECK-NOIMM8-NEXT: smlal v3.4s, v6.4h, v5.4h
+; CHECK-NOIMM8-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NOIMM8-NEXT: add v0.2s, v3.2s, v0.2s
+; CHECK-NOIMM8-NEXT: ret
+ %u.wide = zext <8 x i8> %u to <8 x i32>
+ %s.wide = sext <8 x i8> %s to <8 x i32>
+ %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
+ %partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<2 x i32> %acc, <8 x i32> %mult)
+ ret <2 x i32> %partial.reduce
+}
+
+define <4 x i32> @sudot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) #0{
+; CHECK-DOT-LABEL: sudot:
+; CHECK-DOT: // %bb.0:
+; CHECK-DOT-NEXT: usdot v0.4s, v2.16b, v1.16b
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-NODOT-LABEL: sudot:
+; CHECK-NODOT: // %bb.0:
+; CHECK-NODOT-NEXT: sshll v3.8h, v1.8b, #0
+; CHECK-NODOT-NEXT: sshll2 v1.8h, v1.16b, #0
+; CHECK-NODOT-NEXT: ushll v4.8h, v2.8b, #0
+; CHECK-NODOT-NEXT: ushll2 v2.8h, v2.16b, #0
+; CHECK-NODOT-NEXT: smlal v0.4s, v4.4h, v3.4h
+; CHECK-NODOT-NEXT: smull v5.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT: smlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NODOT-NEXT: smlal2 v5.4s, v4.8h, v3.8h
+; CHECK-NODOT-NEXT: add v0.4s, v5.4s, v0.4s
+; CHECK-NODOT-NEXT: ret
+;
+; CHECK-NOIMM8-LABEL: sudot:
+; CHECK-NOIMM8: // %bb.0:
+; CHECK-NOIMM8-NEXT: sshll v3.8h, v1.8b, #0
+; CHECK-NOIMM8-NEXT: sshll2 v1.8h, v1.16b, #0
+; CHECK-NOIMM8-NEXT: ushll v4.8h, v2.8b, #0
+; CHECK-NOIMM8-NEXT: ushll2 v2.8h, v2.16b, #0
+; CHECK-NOIMM8-NEXT: smlal v0.4s, v4.4h, v3.4h
+; CHECK-NOIMM8-NEXT: smull v5.4s, v2.4h, v1.4h
+; CHECK-NOIMM8-NEXT: smlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NOIMM8-NEXT: smlal2 v5.4s, v4.8h, v3.8h
+; CHECK-NOIMM8-NEXT: add v0.4s, v5.4s, v0.4s
+; CHECK-NOIMM8-NEXT: ret
+ %u.wide = sext <16 x i8> %u to <16 x i32>
+ %s.wide = zext <16 x i8> %s to <16 x i32>
+ %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+ %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mult)
+ ret <4 x i32> %partial.reduce
+}
+
+define <2 x i32> @sudot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
+; CHECK-DOT-LABEL: sudot_narrow:
+; CHECK-DOT: // %bb.0:
+; CHECK-DOT-NEXT: usdot v0.2s, v2.8b, v1.8b
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-NODOT-LABEL: sudot_narrow:
+; CHECK-NODOT: // %bb.0:
+; CHECK-NODOT-NEXT: sshll v1.8h, v1.8b, #0
+; CHECK-NODOT-NEXT: ushll v2.8h, v2.8b, #0
+; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-NODOT-NEXT: smull v3.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT: smull2 v4.4s, v2.8h, v1.8h
+; CHECK-NODOT-NEXT: ext v5.16b, v1.16b, v1.16b, #8
+; CHECK-NODOT-NEXT: ext v6.16b, v2.16b, v2.16b, #8
+; CHECK-NODOT-NEXT: smlal v0.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NODOT-NEXT: ext v1.16b, v4.16b, v4.16b, #8
+; CHECK-NODOT-NEXT: smlal v3.4s, v6.4h, v5.4h
+; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
+; CHECK-NODOT-NEXT: ret
+;
+; CHECK-NOIMM8-LABEL: sudot_narrow:
+; CHECK-NOIMM8: // %bb.0:
+; CHECK-NOIMM8-NEXT: sshll v1.8h, v1.8b, #0
+; CHECK-NOIMM8-NEXT: ushll v2.8h, v2.8b, #0
+; CHECK-NOIMM8-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-NOIMM8-NEXT: smull v3.4s, v2.4h, v1.4h
+; CHECK-NOIMM8-NEXT: smull2 v4.4s, v2.8h, v1.8h
+; CHECK-NOIMM8-NEXT: ext v5.16b, v1.16b, v1.16b, #8
+; CHECK-NOIMM8-NEXT: ext v6.16b, v2.16b, v2.16b, #8
+; CHECK-NOIMM8-NEXT: smlal v0.4s, v2.4h, v1.4h
+; CHECK-NOIMM8-NEXT: ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NOIMM8-NEXT: ext v1.16b, v4.16b, v4.16b, #8
+; CHECK-NOIMM8-NEXT: smlal v3.4s, v6.4h, v5.4h
+; CHECK-NOIMM8-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NOIMM8-NEXT: add v0.2s, v3.2s, v0.2s
+; CHECK-NOIMM8-NEXT: ret
+ %u.wide = sext <8 x i8> %u to <8 x i32>
+ %s.wide = zext <8 x i8> %s to <8 x i32>
+ %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
+ %partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<2 x i32> %acc, <8 x i32> %mult)
+ ret <2 x i32> %partial.reduce
+}
+
+define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
; CHECK-LABEL: not_udot:
; CHECK: // %bb.0:
; CHECK-NEXT: umull v1.8h, v2.8b, v1.8b
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index b1354ab210f727..35d2b8ca30a041 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -1,8 +1,9 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s
+; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NOIMM8
-define <vscale x 4 x i32> @dotp(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
-; CHECK-LABEL: dotp:
+define <vscale x 4 x i32> @udot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-LABEL: udot:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: udot z0.s, z1.b, z2.b
; CHECK-NEXT: ret
@@ -14,8 +15,8 @@ entry:
ret <vscale x 4 x i32> %partial.reduce
}
-define <vscale x 2 x i64> @dotp_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
-; CHECK-LABEL: dotp_wide:
+define <vscale x 2 x i64> @udot_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: udot_wide:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: udot z0.d, z1.h, z2.h
; CHECK-NEXT: ret
@@ -27,8 +28,8 @@ entry:
ret <vscale x 2 x i64> %partial.reduce
}
-define <vscale x 4 x i32> @dotp_sext(<vscale x 4 x i32> %accc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
-; CHECK-LABEL: dotp_sext:
+define <vscale x 4 x i32> @sdot(<vscale x 4 x i32> %accc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-LABEL: sdot:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sdot z0.s, z1.b, z2.b
; CHECK-NEXT: ret
@@ -40,8 +41,8 @@ entry:
ret <vscale x 4 x i32> %partial.reduce
}
-define <vscale x 2 x i64> @dotp_wide_sext(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
-; CHECK-LABEL: dotp_wide_sext:
+define <vscale x 2 x i64> @sdot_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: sdot_wide:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sdot z0.d, z1.h, z2.h
; CHECK-NEXT: ret
@@ -53,8 +54,80 @@ entry:
ret <vscale x 2 x i64> %partial.reduce
}
-define <vscale x 4 x i32> @not_dotp(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
-; CHECK-LABEL: not_dotp:
+define <vscale x 4 x i32> @usdot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-DOT-LABEL: usdot:
+; CHECK-DOT: // %bb.0: // %entry
+; CHECK-DOT-NEXT: usdot z0.s, z1.b, z2.b
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-NOIMM8-LABEL: usdot:
+; CHECK-NOIMM8: // %bb.0: // %entry
+; CHECK-NOIMM8-NEXT: uunpklo z3.h, z1.b
+; CHECK-NOIMM8-NEXT: sunpklo z4.h, z2.b
+; CHECK-NOIMM8-NEXT: uunpkhi z1.h, z1.b
+; CHECK-NOIMM8-NEXT: sunpkhi z2.h, z2.b
+; CHECK-NOIMM8-NEXT: ptrue p0.s
+; CHECK-NOIMM8-NEXT: uunpklo z5.s, z3.h
+; CHECK-NOIMM8-NEXT: uunpkhi z3.s, z3.h
+; CHECK-NOIMM8-NEXT: sunpklo z6.s, z4.h
+; CHECK-NOIMM8-NEXT: sunpkhi z4.s, z4.h
+; CHECK-NOIMM8-NEXT: uunpklo z7.s, z1.h
+; CHECK-NOIMM8-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NOIMM8-NEXT: sunpklo z24.s, z2.h
+; CHECK-NOIMM8-NEXT: sunpkhi z2.s, z2.h
+; CHECK-NOIMM8-NEXT: mla z0.s, p0/m, z5.s, z6.s
+; CHECK-NOIMM8-NEXT: mul z3.s, z3.s, z4.s
+; CHECK-NOIMM8-NEXT: mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NOIMM8-NEXT: movprfx z1, z3
+; CHECK-NOIMM8-NEXT: mla z1.s, p0/m, z7.s, z24.s
+; CHECK-NOIMM8-NEXT: add z0.s, z1.s, z0.s
+; CHECK-NOIMM8-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+ %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+ ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @sudot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-DOT-LABEL: sudot:
+; CHECK-DOT: // %bb.0: // %entry
+; CHECK-DOT-NEXT: usdot z0.s, z2.b, z1.b
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-NOIMM8-LABEL: sudot:
+; CHECK-NOIMM8: // %bb.0: // %entry
+; CHECK-NOIMM8-NEXT: sunpklo z3.h, z1.b
+; CHECK-NOIMM8-NEXT: uunpklo z4.h, z2.b
+; CHECK-NOIMM8-NEXT: sunpkhi z1.h, z1.b
+; CHECK-NOIMM8-NEXT: uunpkhi z2.h, z2.b
+; CHECK-NOIMM8-NEXT: ptrue p0.s
+; CHECK-NOIMM8-NEXT: sunpklo z5.s, z3.h
+; CHECK-NOIMM8-NEXT: sunpkhi z3.s, z3.h
+; CHECK-NOIMM8-NEXT: uunpklo z6.s, z4.h
+; CHECK-NOIMM8-NEXT: uunpkhi z4.s, z4.h
+; CHECK-NOIMM8-NEXT: sunpklo z7.s, z1.h
+; CHECK-NOIMM8-NEXT: sunpkhi z1.s, z1.h
+; CHECK-NOIMM8-NEXT: uunpklo z24.s, z2.h
+; CHECK-NOIMM8-NEXT: uunpkhi z2.s, z2.h
+; CHECK-NOIMM8-NEXT: mla z0.s, p0/m, z5.s, z6.s
+; CHECK-NOIMM8-NEXT: mul z3.s, z3.s, z4.s
+; CHECK-NOIMM8-NEXT: mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NOIMM8-NEXT: movprfx z1, z3
+; CHECK-NOIMM8-NEXT: mla z1.s, p0/m, z7.s, z24.s
+; CHECK-NOIMM8-NEXT: add z0.s, z1.s, z0.s
+; CHECK-NOIMM8-NEXT: ret
+entry:
+ %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+ %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+ ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @not_udot(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
+; CHECK-LABEL: not_udot:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: and z1.h, z1.h, #0xff
; CHECK-NEXT: and z2.h, z2.h, #0xff
@@ -74,8 +147,8 @@ entry:
ret <vscale x 4 x i32> %partial.reduce
}
-define <vscale x 2 x i64> @not_dotp_wide(<vscale x 2 x i64> %acc, <vscale x 4 x i16> %a, <vscale x 4 x i16> %b) {
-; CHECK-LABEL: not_dotp_wide:
+define <vscale x 2 x i64> @not_udot_wide(<vscale x 2 x i64> %acc, <vscale x 4 x i16> %a, <vscale x 4 x i16> %b) {
+; CHECK-LABEL: not_udot_wide:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: and z1.s, z1.s, #0xffff
; CHECK-NEXT: and z2.s, z2.s, #0xffff
@@ -94,3 +167,65 @@ entry:
%partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %mult)
ret <vscale x 2 x i64> %partial.reduce
}
+
+define <vscale x 2 x i64> @not_usdot(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: not_usdot:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: uunpklo z3.s, z1.h
+; CHECK-NEXT: sunpklo z4.s, z2.h
+; CHECK-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEXT: sunpkhi z2.s, z2.h
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: uunpklo z5.d, z3.s
+; CHECK-NEXT: uunpkhi z3.d, z3.s
+; CHECK-NEXT: sunpklo z6.d, z4.s
+; CHECK-NEXT: sunpkhi z4.d, z4.s
+; CHECK-NEXT: uunpklo z7.d, z1.s
+; CHECK-NEXT: uunpkhi z1.d, z1.s
+; CHECK-NEXT: sunpklo z24.d, z2.s
+; CHECK-NEXT: sunpkhi z2.d, z2.s
+; CHECK-NEXT: mla z0.d, p0/m, z5.d, z6.d
+; CHECK-NEXT: mul z3.d, z3.d, z4.d
+; CHECK-NEXT: mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: movprfx z1, z3
+; CHECK-NEXT: mla z1.d, p0/m, z7.d, z24.d
+; CHECK-NEXT: add z0.d, z1.d, z0.d
+; CHECK-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+ %b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i64>
+ %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
+ ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 2 x i64> @not_sudot(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: not_sudot:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: sunpklo z3.s, z1.h
+; CHECK-NEXT: uunpklo z4.s, z2.h
+; CHECK-NEXT: sunpkhi z1.s, z1.h
+; CHECK-NEXT: uunpkhi z2.s, z2.h
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: sunpklo z5.d, z3.s
+; CHECK-NEXT: sunpkhi z3.d, z3.s
+; CHECK-NEXT: uunpklo z6.d, z4.s
+; CHECK-NEXT: uunpkhi z4.d, z4.s
+; CHECK-NEXT: sunpklo z7.d, z1.s
+; CHECK-NEXT: sunpkhi z1.d, z1.s
+; CHECK-NEXT: uunpklo z24.d, z2.s
+; CHECK-NEXT: uunpkhi z2.d, z2.s
+; CHECK-NEXT: mla z0.d, p0/m, z5.d, z6.d
+; CHECK-NEXT: mul z3.d, z3.d, z4.d
+; CHECK-NEXT: mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: movprfx z1, z3
+; CHECK-NEXT: mla z1.d, p0/m, z7.d, z24.d
+; CHECK-NEXT: add z0.d, z1.d, z0.d
+; CHECK-NEXT: ret
+entry:
+ %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+ %b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i64>
+ %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
+ ret <vscale x 2 x i64> %partial.reduce
+}