[llvm] [AArch64] Extend vecreduce to udot/sdot transformation to support usdot (PR #120094)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Dec 16 06:54:48 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-aarch64
Author: Igor Kirillov (igogo-x86)
<details>
<summary>Changes</summary>
---
Patch is 226.53 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/120094.diff
2 Files Affected:
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+27-7)
- (modified) llvm/test/CodeGen/AArch64/neon-dotreduce.ll (+4445-997)
``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index c19265613c706d..3ef6ef356465d3 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -18283,16 +18283,38 @@ static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
unsigned ExtOpcode = Op0.getOpcode();
SDValue A = Op0;
SDValue B;
+ unsigned DotOpcode;
if (ExtOpcode == ISD::MUL) {
A = Op0.getOperand(0);
B = Op0.getOperand(1);
- if (A.getOpcode() != B.getOpcode() ||
- A.getOperand(0).getValueType() != B.getOperand(0).getValueType())
+ if (A.getOperand(0).getValueType() != B.getOperand(0).getValueType())
return SDValue();
- ExtOpcode = A.getOpcode();
- }
- if (ExtOpcode != ISD::ZERO_EXTEND && ExtOpcode != ISD::SIGN_EXTEND)
+ auto OpCodeA = A.getOpcode();
+ if (OpCodeA != ISD::ZERO_EXTEND && OpCodeA != ISD::SIGN_EXTEND)
+ return SDValue();
+
+ auto OpCodeB = B.getOpcode();
+ if (OpCodeB != ISD::ZERO_EXTEND && OpCodeB != ISD::SIGN_EXTEND)
+ return SDValue();
+
+ if (OpCodeA == OpCodeB) {
+ DotOpcode =
+ OpCodeA == ISD::ZERO_EXTEND ? AArch64ISD::UDOT : AArch64ISD::SDOT;
+ } else {
+      // Check USDOT support
+ if (!ST->hasMatMulInt8())
+ return SDValue();
+ DotOpcode = AArch64ISD::USDOT;
+ if (OpCodeA == ISD::SIGN_EXTEND)
+ std::swap(A, B);
+ }
+ } else if (ExtOpcode == ISD::ZERO_EXTEND) {
+ DotOpcode = AArch64ISD::UDOT;
+ } else if (ExtOpcode == ISD::SIGN_EXTEND) {
+ DotOpcode = AArch64ISD::SDOT;
+ } else {
return SDValue();
+ }
EVT Op0VT = A.getOperand(0).getValueType();
bool IsValidElementCount = Op0VT.getVectorNumElements() % 8 == 0;
@@ -18318,8 +18340,6 @@ static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
NumOfVecReduce = Op0VT.getVectorNumElements() / 8;
TargetType = MVT::v2i32;
}
- auto DotOpcode =
- (ExtOpcode == ISD::ZERO_EXTEND) ? AArch64ISD::UDOT : AArch64ISD::SDOT;
// Handle the case where we need to generate only one Dot operation.
if (NumOfVecReduce == 1) {
SDValue Zeros = DAG.getConstant(0, DL, TargetType);
diff --git a/llvm/test/CodeGen/AArch64/neon-dotreduce.ll b/llvm/test/CodeGen/AArch64/neon-dotreduce.ll
index c345c1e50bbbb7..05ac2956da00c7 100644
--- a/llvm/test/CodeGen/AArch64/neon-dotreduce.ll
+++ b/llvm/test/CodeGen/AArch64/neon-dotreduce.ll
@@ -1,22 +1,28 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple aarch64-linux-gnu -mattr=+dotprod < %s | FileCheck %s --check-prefixes=CHECK,CHECK-SD
-; RUN: llc -mtriple aarch64-linux-gnu -mattr=+dotprod -global-isel -global-isel-abort=2 < %s 2>&1 | FileCheck %s --check-prefixes=CHECK,CHECK-GI
+; RUN: llc -mtriple aarch64-linux-gnu -mattr=+dotprod,+i8mm < %s | FileCheck %s --check-prefixes=CHECK,CHECK-SD
+; RUN: llc -mtriple aarch64-linux-gnu -mattr=+dotprod,+i8mm -global-isel -global-isel-abort=2 < %s 2>&1 | FileCheck %s --check-prefixes=CHECK,CHECK-GI
; CHECK-GI: warning: Instruction selection used fallback path for test_udot_v5i8
; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_udot_v5i8_nomla
; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_sdot_v5i8
; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_sdot_v5i8_double
; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_sdot_v5i8_double_nomla
+; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_usdot_v5i8
+; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_usdot_v5i8_double
; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_udot_v25i8
; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_udot_v25i8_nomla
; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_sdot_v25i8
; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_sdot_v25i8_double
; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_sdot_v25i8_double_nomla
+; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_usdot_v25i8
+; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_usdot_v25i8_double
; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_udot_v33i8
; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_udot_v33i8_nomla
; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_sdot_v33i8
; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_sdot_v33i8_double
; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_sdot_v33i8_double_nomla
+; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_usdot_v33i8
+; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_usdot_v33i8_double
declare i32 @llvm.vector.reduce.add.v4i32(<4 x i32>)
declare i32 @llvm.vector.reduce.add.v5i32(<5 x i32>)
@@ -290,6 +296,128 @@ entry:
ret i32 %x
}
+define i32 @test_usdot_v4i8(ptr nocapture readonly %a, ptr nocapture readonly %b, i32 %sum) {
+; CHECK-SD-LABEL: test_usdot_v4i8:
+; CHECK-SD: // %bb.0: // %entry
+; CHECK-SD-NEXT: ldr s0, [x0]
+; CHECK-SD-NEXT: ldr s1, [x1]
+; CHECK-SD-NEXT: ushll v0.8h, v0.8b, #0
+; CHECK-SD-NEXT: sshll v1.8h, v1.8b, #0
+; CHECK-SD-NEXT: smull v0.4s, v1.4h, v0.4h
+; CHECK-SD-NEXT: addv s0, v0.4s
+; CHECK-SD-NEXT: fmov w8, s0
+; CHECK-SD-NEXT: add w0, w8, w2
+; CHECK-SD-NEXT: ret
+;
+; CHECK-GI-LABEL: test_usdot_v4i8:
+; CHECK-GI: // %bb.0: // %entry
+; CHECK-GI-NEXT: ldr w8, [x0]
+; CHECK-GI-NEXT: ldr w9, [x1]
+; CHECK-GI-NEXT: fmov s0, w8
+; CHECK-GI-NEXT: fmov s2, w9
+; CHECK-GI-NEXT: uxtb w8, w8
+; CHECK-GI-NEXT: sxtb w9, w9
+; CHECK-GI-NEXT: mov b1, v0.b[1]
+; CHECK-GI-NEXT: mov b3, v0.b[2]
+; CHECK-GI-NEXT: mov b5, v2.b[2]
+; CHECK-GI-NEXT: mov b4, v0.b[3]
+; CHECK-GI-NEXT: mov b0, v2.b[1]
+; CHECK-GI-NEXT: mov b6, v2.b[3]
+; CHECK-GI-NEXT: fmov s2, w9
+; CHECK-GI-NEXT: fmov w10, s1
+; CHECK-GI-NEXT: fmov w11, s3
+; CHECK-GI-NEXT: fmov s1, w8
+; CHECK-GI-NEXT: fmov w13, s5
+; CHECK-GI-NEXT: fmov w8, s4
+; CHECK-GI-NEXT: fmov w12, s0
+; CHECK-GI-NEXT: uxtb w10, w10
+; CHECK-GI-NEXT: uxtb w11, w11
+; CHECK-GI-NEXT: sxtb w13, w13
+; CHECK-GI-NEXT: uxtb w8, w8
+; CHECK-GI-NEXT: sxtb w12, w12
+; CHECK-GI-NEXT: mov v1.h[1], w10
+; CHECK-GI-NEXT: fmov w10, s6
+; CHECK-GI-NEXT: fmov s0, w11
+; CHECK-GI-NEXT: fmov s3, w13
+; CHECK-GI-NEXT: mov v2.h[1], w12
+; CHECK-GI-NEXT: sxtb w10, w10
+; CHECK-GI-NEXT: mov v0.h[1], w8
+; CHECK-GI-NEXT: ushll v1.4s, v1.4h, #0
+; CHECK-GI-NEXT: mov v3.h[1], w10
+; CHECK-GI-NEXT: sshll v2.4s, v2.4h, #0
+; CHECK-GI-NEXT: ushll v0.4s, v0.4h, #0
+; CHECK-GI-NEXT: sshll v3.4s, v3.4h, #0
+; CHECK-GI-NEXT: mov v1.d[1], v0.d[0]
+; CHECK-GI-NEXT: mov v2.d[1], v3.d[0]
+; CHECK-GI-NEXT: mul v0.4s, v2.4s, v1.4s
+; CHECK-GI-NEXT: addv s0, v0.4s
+; CHECK-GI-NEXT: fmov w8, s0
+; CHECK-GI-NEXT: add w0, w8, w2
+; CHECK-GI-NEXT: ret
+entry:
+ %0 = load <4 x i8>, ptr %a
+ %1 = zext <4 x i8> %0 to <4 x i32>
+ %2 = load <4 x i8>, ptr %b
+ %3 = sext <4 x i8> %2 to <4 x i32>
+ %4 = mul nsw <4 x i32> %3, %1
+ %5 = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %4)
+ %op.extra = add nsw i32 %5, %sum
+ ret i32 %op.extra
+}
+
+define i32 @test_usdot_v4i8_double(<4 x i8> %a, <4 x i8> %b, <4 x i8> %c, <4 x i8> %d) {
+; CHECK-SD-LABEL: test_usdot_v4i8_double:
+; CHECK-SD: // %bb.0: // %entry
+; CHECK-SD-NEXT: ushll v3.4s, v3.4h, #0
+; CHECK-SD-NEXT: bic v2.4h, #255, lsl #8
+; CHECK-SD-NEXT: ushll v1.4s, v1.4h, #0
+; CHECK-SD-NEXT: bic v0.4h, #255, lsl #8
+; CHECK-SD-NEXT: shl v3.4s, v3.4s, #24
+; CHECK-SD-NEXT: ushll v2.4s, v2.4h, #0
+; CHECK-SD-NEXT: shl v1.4s, v1.4s, #24
+; CHECK-SD-NEXT: ushll v0.4s, v0.4h, #0
+; CHECK-SD-NEXT: sshr v3.4s, v3.4s, #24
+; CHECK-SD-NEXT: sshr v1.4s, v1.4s, #24
+; CHECK-SD-NEXT: mul v2.4s, v2.4s, v3.4s
+; CHECK-SD-NEXT: mla v2.4s, v0.4s, v1.4s
+; CHECK-SD-NEXT: addv s0, v2.4s
+; CHECK-SD-NEXT: fmov w0, s0
+; CHECK-SD-NEXT: ret
+;
+; CHECK-GI-LABEL: test_usdot_v4i8_double:
+; CHECK-GI: // %bb.0: // %entry
+; CHECK-GI-NEXT: ushll v1.4s, v1.4h, #0
+; CHECK-GI-NEXT: ushll v3.4s, v3.4h, #0
+; CHECK-GI-NEXT: movi v4.2d, #0x0000ff000000ff
+; CHECK-GI-NEXT: ushll v0.4s, v0.4h, #0
+; CHECK-GI-NEXT: ushll v2.4s, v2.4h, #0
+; CHECK-GI-NEXT: shl v1.4s, v1.4s, #24
+; CHECK-GI-NEXT: shl v3.4s, v3.4s, #24
+; CHECK-GI-NEXT: and v0.16b, v0.16b, v4.16b
+; CHECK-GI-NEXT: and v2.16b, v2.16b, v4.16b
+; CHECK-GI-NEXT: sshr v1.4s, v1.4s, #24
+; CHECK-GI-NEXT: sshr v3.4s, v3.4s, #24
+; CHECK-GI-NEXT: mul v0.4s, v0.4s, v1.4s
+; CHECK-GI-NEXT: mul v1.4s, v2.4s, v3.4s
+; CHECK-GI-NEXT: addv s0, v0.4s
+; CHECK-GI-NEXT: addv s1, v1.4s
+; CHECK-GI-NEXT: fmov w8, s0
+; CHECK-GI-NEXT: fmov w9, s1
+; CHECK-GI-NEXT: add w0, w8, w9
+; CHECK-GI-NEXT: ret
+entry:
+ %az = zext <4 x i8> %a to <4 x i32>
+ %bz = sext <4 x i8> %b to <4 x i32>
+ %m1 = mul nuw nsw <4 x i32> %az, %bz
+ %r1 = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %m1)
+ %cz = zext <4 x i8> %c to <4 x i32>
+ %dz = sext <4 x i8> %d to <4 x i32>
+ %m2 = mul nuw nsw <4 x i32> %cz, %dz
+ %r2 = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %m2)
+ %x = add i32 %r1, %r2
+ ret i32 %x
+}
+
define i32 @test_udot_v5i8(ptr nocapture readonly %a, ptr nocapture readonly %b, i32 %sum) {
; CHECK-LABEL: test_udot_v5i8:
; CHECK: // %bb.0: // %entry
@@ -414,6 +542,65 @@ entry:
ret i32 %x
}
+define i32 @test_usdot_v5i8(ptr nocapture readonly %a, ptr nocapture readonly %b, i32 %sum) {
+; CHECK-LABEL: test_usdot_v5i8:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ldr d0, [x0]
+; CHECK-NEXT: ldr d1, [x1]
+; CHECK-NEXT: movi v3.2d, #0000000000000000
+; CHECK-NEXT: ushll v0.8h, v0.8b, #0
+; CHECK-NEXT: sshll v1.8h, v1.8b, #0
+; CHECK-NEXT: smull2 v2.4s, v1.8h, v0.8h
+; CHECK-NEXT: mov v3.s[0], v2.s[0]
+; CHECK-NEXT: smlal v3.4s, v1.4h, v0.4h
+; CHECK-NEXT: addv s0, v3.4s
+; CHECK-NEXT: fmov w8, s0
+; CHECK-NEXT: add w0, w8, w2
+; CHECK-NEXT: ret
+entry:
+ %0 = load <5 x i8>, ptr %a
+ %1 = zext <5 x i8> %0 to <5 x i32>
+ %2 = load <5 x i8>, ptr %b
+ %3 = sext <5 x i8> %2 to <5 x i32>
+ %4 = mul nsw <5 x i32> %3, %1
+ %5 = call i32 @llvm.vector.reduce.add.v5i32(<5 x i32> %4)
+ %op.extra = add nsw i32 %5, %sum
+ ret i32 %op.extra
+}
+
+define i32 @test_usdot_v5i8_double(<5 x i8> %a, <5 x i8> %b, <5 x i8> %c, <5 x i8> %d) {
+; CHECK-LABEL: test_usdot_v5i8_double:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: sshll v1.8h, v1.8b, #0
+; CHECK-NEXT: ushll v0.8h, v0.8b, #0
+; CHECK-NEXT: ushll v2.8h, v2.8b, #0
+; CHECK-NEXT: sshll v3.8h, v3.8b, #0
+; CHECK-NEXT: movi v5.2d, #0000000000000000
+; CHECK-NEXT: movi v6.2d, #0000000000000000
+; CHECK-NEXT: smull2 v4.4s, v0.8h, v1.8h
+; CHECK-NEXT: smull2 v7.4s, v2.8h, v3.8h
+; CHECK-NEXT: mov v6.s[0], v4.s[0]
+; CHECK-NEXT: mov v5.s[0], v7.s[0]
+; CHECK-NEXT: smlal v6.4s, v0.4h, v1.4h
+; CHECK-NEXT: smlal v5.4s, v2.4h, v3.4h
+; CHECK-NEXT: add v0.4s, v6.4s, v5.4s
+; CHECK-NEXT: addv s0, v0.4s
+; CHECK-NEXT: fmov w0, s0
+; CHECK-NEXT: ret
+entry:
+ %az = zext <5 x i8> %a to <5 x i32>
+ %bz = sext <5 x i8> %b to <5 x i32>
+ %m1 = mul nuw nsw <5 x i32> %az, %bz
+ %r1 = call i32 @llvm.vector.reduce.add.v5i32(<5 x i32> %m1)
+ %cz = zext <5 x i8> %c to <5 x i32>
+ %dz = sext <5 x i8> %d to <5 x i32>
+ %m2 = mul nuw nsw <5 x i32> %cz, %dz
+ %r2 = call i32 @llvm.vector.reduce.add.v5i32(<5 x i32> %m2)
+ %x = add i32 %r1, %r2
+ ret i32 %x
+}
+
+
define i32 @test_udot_v8i8(ptr nocapture readonly %a, ptr nocapture readonly %b) {
; CHECK-LABEL: test_udot_v8i8:
; CHECK: // %bb.0: // %entry
@@ -508,6 +695,77 @@ entry:
ret i32 %2
}
+define i32 @test_usdot_v8i8(ptr nocapture readonly %a, ptr nocapture readonly %b) {
+; CHECK-SD-LABEL: test_usdot_v8i8:
+; CHECK-SD: // %bb.0: // %entry
+; CHECK-SD-NEXT: movi v0.2d, #0000000000000000
+; CHECK-SD-NEXT: ldr d1, [x0]
+; CHECK-SD-NEXT: ldr d2, [x1]
+; CHECK-SD-NEXT: usdot v0.2s, v1.8b, v2.8b
+; CHECK-SD-NEXT: addp v0.2s, v0.2s, v0.2s
+; CHECK-SD-NEXT: fmov w0, s0
+; CHECK-SD-NEXT: ret
+;
+; CHECK-GI-LABEL: test_usdot_v8i8:
+; CHECK-GI: // %bb.0: // %entry
+; CHECK-GI-NEXT: ldr d0, [x0]
+; CHECK-GI-NEXT: ldr d1, [x1]
+; CHECK-GI-NEXT: ushll v0.8h, v0.8b, #0
+; CHECK-GI-NEXT: sshll v1.8h, v1.8b, #0
+; CHECK-GI-NEXT: ushll2 v2.4s, v0.8h, #0
+; CHECK-GI-NEXT: sshll2 v3.4s, v1.8h, #0
+; CHECK-GI-NEXT: ushll v0.4s, v0.4h, #0
+; CHECK-GI-NEXT: sshll v1.4s, v1.4h, #0
+; CHECK-GI-NEXT: mul v2.4s, v3.4s, v2.4s
+; CHECK-GI-NEXT: mla v2.4s, v1.4s, v0.4s
+; CHECK-GI-NEXT: addv s0, v2.4s
+; CHECK-GI-NEXT: fmov w0, s0
+; CHECK-GI-NEXT: ret
+entry:
+ %0 = load <8 x i8>, ptr %a
+ %1 = zext <8 x i8> %0 to <8 x i32>
+ %2 = load <8 x i8>, ptr %b
+ %3 = sext <8 x i8> %2 to <8 x i32>
+ %4 = mul nsw <8 x i32> %3, %1
+ %5 = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %4)
+ ret i32 %5
+}
+
+define i32 @test_usdot_swapped_operands_v8i8(ptr nocapture readonly %a, ptr nocapture readonly %b) {
+; CHECK-SD-LABEL: test_usdot_swapped_operands_v8i8:
+; CHECK-SD: // %bb.0: // %entry
+; CHECK-SD-NEXT: movi v0.2d, #0000000000000000
+; CHECK-SD-NEXT: ldr d1, [x0]
+; CHECK-SD-NEXT: ldr d2, [x1]
+; CHECK-SD-NEXT: usdot v0.2s, v2.8b, v1.8b
+; CHECK-SD-NEXT: addp v0.2s, v0.2s, v0.2s
+; CHECK-SD-NEXT: fmov w0, s0
+; CHECK-SD-NEXT: ret
+;
+; CHECK-GI-LABEL: test_usdot_swapped_operands_v8i8:
+; CHECK-GI: // %bb.0: // %entry
+; CHECK-GI-NEXT: ldr d0, [x0]
+; CHECK-GI-NEXT: ldr d1, [x1]
+; CHECK-GI-NEXT: sshll v0.8h, v0.8b, #0
+; CHECK-GI-NEXT: ushll v1.8h, v1.8b, #0
+; CHECK-GI-NEXT: sshll2 v2.4s, v0.8h, #0
+; CHECK-GI-NEXT: ushll2 v3.4s, v1.8h, #0
+; CHECK-GI-NEXT: sshll v0.4s, v0.4h, #0
+; CHECK-GI-NEXT: ushll v1.4s, v1.4h, #0
+; CHECK-GI-NEXT: mul v2.4s, v3.4s, v2.4s
+; CHECK-GI-NEXT: mla v2.4s, v1.4s, v0.4s
+; CHECK-GI-NEXT: addv s0, v2.4s
+; CHECK-GI-NEXT: fmov w0, s0
+; CHECK-GI-NEXT: ret
+entry:
+ %0 = load <8 x i8>, ptr %a
+ %1 = sext <8 x i8> %0 to <8 x i32>
+ %2 = load <8 x i8>, ptr %b
+ %3 = zext <8 x i8> %2 to <8 x i32>
+ %4 = mul nsw <8 x i32> %3, %1
+ %5 = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %4)
+ ret i32 %5
+}
define i32 @test_udot_v16i8(ptr nocapture readonly %a, ptr nocapture readonly %b, i32 %sum) {
; CHECK-LABEL: test_udot_v16i8:
@@ -587,6 +845,101 @@ entry:
ret i32 %2
}
+define i32 @test_usdot_v16i8(ptr nocapture readonly %a, ptr nocapture readonly %b, i32 %sum) {
+; CHECK-SD-LABEL: test_usdot_v16i8:
+; CHECK-SD: // %bb.0: // %entry
+; CHECK-SD-NEXT: movi v0.2d, #0000000000000000
+; CHECK-SD-NEXT: ldr q1, [x0]
+; CHECK-SD-NEXT: ldr q2, [x1]
+; CHECK-SD-NEXT: usdot v0.4s, v1.16b, v2.16b
+; CHECK-SD-NEXT: addv s0, v0.4s
+; CHECK-SD-NEXT: fmov w8, s0
+; CHECK-SD-NEXT: add w0, w8, w2
+; CHECK-SD-NEXT: ret
+;
+; CHECK-GI-LABEL: test_usdot_v16i8:
+; CHECK-GI: // %bb.0: // %entry
+; CHECK-GI-NEXT: ldr q0, [x0]
+; CHECK-GI-NEXT: ldr q1, [x1]
+; CHECK-GI-NEXT: ushll v2.8h, v0.8b, #0
+; CHECK-GI-NEXT: ushll2 v0.8h, v0.16b, #0
+; CHECK-GI-NEXT: sshll v3.8h, v1.8b, #0
+; CHECK-GI-NEXT: sshll2 v1.8h, v1.16b, #0
+; CHECK-GI-NEXT: ushll2 v4.4s, v2.8h, #0
+; CHECK-GI-NEXT: ushll2 v5.4s, v0.8h, #0
+; CHECK-GI-NEXT: sshll2 v6.4s, v3.8h, #0
+; CHECK-GI-NEXT: sshll2 v7.4s, v1.8h, #0
+; CHECK-GI-NEXT: ushll v2.4s, v2.4h, #0
+; CHECK-GI-NEXT: ushll v0.4s, v0.4h, #0
+; CHECK-GI-NEXT: sshll v3.4s, v3.4h, #0
+; CHECK-GI-NEXT: sshll v1.4s, v1.4h, #0
+; CHECK-GI-NEXT: mul v4.4s, v6.4s, v4.4s
+; CHECK-GI-NEXT: mul v5.4s, v7.4s, v5.4s
+; CHECK-GI-NEXT: mla v4.4s, v3.4s, v2.4s
+; CHECK-GI-NEXT: mla v5.4s, v1.4s, v0.4s
+; CHECK-GI-NEXT: add v0.4s, v4.4s, v5.4s
+; CHECK-GI-NEXT: addv s0, v0.4s
+; CHECK-GI-NEXT: fmov w8, s0
+; CHECK-GI-NEXT: add w0, w8, w2
+; CHECK-GI-NEXT: ret
+entry:
+ %0 = load <16 x i8>, ptr %a
+ %1 = zext <16 x i8> %0 to <16 x i32>
+ %2 = load <16 x i8>, ptr %b
+ %3 = sext <16 x i8> %2 to <16 x i32>
+ %4 = mul nsw <16 x i32> %3, %1
+ %5 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %4)
+ %op.extra = add nsw i32 %5, %sum
+ ret i32 %op.extra
+}
+
+define i32 @test_usdot_swapped_operands_v16i8(ptr nocapture readonly %a, ptr nocapture readonly %b, i32 %sum) {
+; CHECK-SD-LABEL: test_usdot_swapped_operands_v16i8:
+; CHECK-SD: // %bb.0: // %entry
+; CHECK-SD-NEXT: movi v0.2d, #0000000000000000
+; CHECK-SD-NEXT: ldr q1, [x0]
+; CHECK-SD-NEXT: ldr q2, [x1]
+; CHECK-SD-NEXT: usdot v0.4s, v2.16b, v1.16b
+; CHECK-SD-NEXT: addv s0, v0.4s
+; CHECK-SD-NEXT: fmov w8, s0
+; CHECK-SD-NEXT: add w0, w8, w2
+; CHECK-SD-NEXT: ret
+;
+; CHECK-GI-LABEL: test_usdot_swapped_operands_v16i8:
+; CHECK-GI: // %bb.0: // %entry
+; CHECK-GI-NEXT: ldr q0, [x0]
+; CHECK-GI-NEXT: ldr q1, [x1]
+; CHECK-GI-NEXT: sshll v2.8h, v0.8b, #0
+; CHECK-GI-NEXT: sshll2 v0.8h, v0.16b, #0
+; CHECK-GI-NEXT: ushll v3.8h, v1.8b, #0
+; CHECK-GI-NEXT: ushll2 v1.8h, v1.16b, #0
+; CHECK-GI-NEXT: sshll2 v4.4s, v2.8h, #0
+; CHECK-GI-NEXT: sshll2 v5.4s, v0.8h, #0
+; CHECK-GI-NEXT: ushll2 v6.4s, v3.8h, #0
+; CHECK-GI-NEXT: ushll2 v7.4s, v1.8h, #0
+; CHECK-GI-NEXT: sshll v2.4s, v2.4h, #0
+; CHECK-GI-NEXT: sshll v0.4s, v0.4h, #0
+; CHECK-GI-NEXT: ushll v3.4s, v3.4h, #0
+; CHECK-GI-NEXT: ushll v1.4s, v1.4h, #0
+; CHECK-GI-NEXT: mul v4.4s, v6.4s, v4.4s
+; CHECK-GI-NEXT: mul v5.4s, v7.4s, v5.4s
+; CHECK-GI-NEXT: mla v4.4s, v3.4s, v2.4s
+; CHECK-GI-NEXT: mla v5.4s, v1.4s, v0.4s
+; CHECK-GI-NEXT: add v0.4s, v4.4s, v5.4s
+; CHECK-GI-NEXT: addv s0, v0.4s
+; CHECK-GI-NEXT: fmov w8, s0
+; CHECK-GI-NEXT: add w0, w8, w2
+; CHECK-GI-NEXT: ret
+entry:
+ %0 = load <16 x i8>, ptr %a
+ %1 = sext <16 x i8> %0 to <16 x i32>
+ %2 = load <16 x i8>, ptr %b
+ %3 = zext <16 x i8> %2 to <16 x i32>
+ %4 = mul nsw <16 x i32> %3, %1
+ %5 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %4)
+ %op.extra = add nsw i32 %5, %sum
+ ret i32 %op.extra
+}
define i32 @test_udot_v8i8_double(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d) {
; CHECK-SD-LABEL: test_udot_v8i8_double:
@@ -860,19 +1213,253 @@ entry:
ret i32 %x
}
-define i32 @test_udot_v24i8(ptr nocapture readonly %a, ptr nocapture readonly %b, i32 %sum) {
-; CHECK-SD-LABEL: test_udot_v24i8:
+
+define i32 @test_usdot_v8i8_double(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d) {
+; CHECK-SD-LABEL: test_usdot_v8i8_double:
; CHECK-SD: // %bb.0: // %entry
-; CHECK-SD-NEXT: movi v0.2d, #0000000000000000
-; CHECK-SD-NEXT: movi v1.2d, #0000000000000000
-; CHECK-SD-NEXT: ldr q2, [x0]
-; CHECK-SD-NEXT: ldr q3, [x1]
-; CHECK-SD-NEXT: ldr d4, [x0, #16]
-; CHECK-SD-NEXT: ldr d5, [x1, #16]
-; CHECK-SD-NEXT: udot v1.2s, v5.8b, v4.8b
-; CHECK-SD-NEXT: udot v0.4s, v3.16b, v2.16b
-; CHECK-SD-NEXT: addp v1.2s, v1.2s, v1.2s
-; CHECK-SD-NEXT: addv s0, v0.4s
+; CHECK-SD-NEXT: movi v4.2d, #0000000000000000
+; CHECK-SD-NEXT: movi v5.2d, #0000000000000000
+; CHECK-SD-NEXT: usdot v5.2s, v0.8b, v1.8b
+; CHECK-SD-NEXT: usdot v4.2s, v2.8b, v3.8b
+; CHECK-SD-NEXT: add v0.2s, v5.2s, v4.2s
+; CHECK-SD-NEXT: addp v0.2s, v0.2s, v0.2s
+; CHECK-SD-NEXT: fmov w0, s0
+; CHECK-SD-NEXT: ret
+;
+; CHECK-GI-LABEL: test_usdot_v8i8_double:
+; CHECK-GI: // %bb.0: // %entry
+; CHECK-GI-NEXT: ushll v0.8h, v0.8b, #0
+; CHECK-GI-NEXT: sshll v1.8h, v1.8b, #0
+; CHECK-GI-NEXT: ushll v2.8h, v2.8b, #0
+; CHECK-GI-NEXT: sshll v3.8h, v3.8b, #0
+; CHECK-GI-NEXT: ushll2 v4.4s, v0.8h, #0
+; CHECK-GI-NEXT: sshll2 v5.4s, v1.8h, #0
+; CHECK-GI-NEXT: ushll2 v6.4s, v2.8h, #0
+; CHECK-GI-NEXT: sshll2 v7.4s, v3.8h, #0
+; CHECK-...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/120094
More information about the llvm-commits
mailing list