[llvm] [AArch64] Extend vecreduce to udot/sdot transformation to support usdot (PR #120094)

via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 16 06:54:48 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: Igor Kirillov (igogo-x86)

<details>
<summary>Changes</summary>



---

Patch is 226.53 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/120094.diff


2 Files Affected:

- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+27-7) 
- (modified) llvm/test/CodeGen/AArch64/neon-dotreduce.ll (+4445-997) 


``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index c19265613c706d..3ef6ef356465d3 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -18283,16 +18283,38 @@ static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
   unsigned ExtOpcode = Op0.getOpcode();
   SDValue A = Op0;
   SDValue B;
+  unsigned DotOpcode;
   if (ExtOpcode == ISD::MUL) {
     A = Op0.getOperand(0);
     B = Op0.getOperand(1);
-    if (A.getOpcode() != B.getOpcode() ||
-        A.getOperand(0).getValueType() != B.getOperand(0).getValueType())
+    if (A.getOperand(0).getValueType() != B.getOperand(0).getValueType())
       return SDValue();
-    ExtOpcode = A.getOpcode();
-  }
-  if (ExtOpcode != ISD::ZERO_EXTEND && ExtOpcode != ISD::SIGN_EXTEND)
+    auto OpCodeA = A.getOpcode();
+    if (OpCodeA != ISD::ZERO_EXTEND && OpCodeA != ISD::SIGN_EXTEND)
+      return SDValue();
+
+    auto OpCodeB = B.getOpcode();
+    if (OpCodeB != ISD::ZERO_EXTEND && OpCodeB != ISD::SIGN_EXTEND)
+      return SDValue();
+
+    if (OpCodeA == OpCodeB) {
+      DotOpcode =
+          OpCodeA == ISD::ZERO_EXTEND ? AArch64ISD::UDOT : AArch64ISD::SDOT;
+    } else {
+      // Mixed zero/sign extends need USDOT, which requires the i8mm feature.
+      if (!ST->hasMatMulInt8())
+        return SDValue();
+      DotOpcode = AArch64ISD::USDOT;
+      if (OpCodeA == ISD::SIGN_EXTEND)
+        std::swap(A, B);
+    }
+  } else if (ExtOpcode == ISD::ZERO_EXTEND) {
+    DotOpcode = AArch64ISD::UDOT;
+  } else if (ExtOpcode == ISD::SIGN_EXTEND) {
+    DotOpcode = AArch64ISD::SDOT;
+  } else {
     return SDValue();
+  }
 
   EVT Op0VT = A.getOperand(0).getValueType();
   bool IsValidElementCount = Op0VT.getVectorNumElements() % 8 == 0;
@@ -18318,8 +18340,6 @@ static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
     NumOfVecReduce = Op0VT.getVectorNumElements() / 8;
     TargetType = MVT::v2i32;
   }
-  auto DotOpcode =
-      (ExtOpcode == ISD::ZERO_EXTEND) ? AArch64ISD::UDOT : AArch64ISD::SDOT;
   // Handle the case where we need to generate only one Dot operation.
   if (NumOfVecReduce == 1) {
     SDValue Zeros = DAG.getConstant(0, DL, TargetType);
diff --git a/llvm/test/CodeGen/AArch64/neon-dotreduce.ll b/llvm/test/CodeGen/AArch64/neon-dotreduce.ll
index c345c1e50bbbb7..05ac2956da00c7 100644
--- a/llvm/test/CodeGen/AArch64/neon-dotreduce.ll
+++ b/llvm/test/CodeGen/AArch64/neon-dotreduce.ll
@@ -1,22 +1,28 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple aarch64-linux-gnu -mattr=+dotprod    < %s | FileCheck %s --check-prefixes=CHECK,CHECK-SD
-; RUN: llc -mtriple aarch64-linux-gnu -mattr=+dotprod -global-isel -global-isel-abort=2 < %s 2>&1 | FileCheck %s --check-prefixes=CHECK,CHECK-GI
+; RUN: llc -mtriple aarch64-linux-gnu -mattr=+dotprod,+i8mm    < %s | FileCheck %s --check-prefixes=CHECK,CHECK-SD
+; RUN: llc -mtriple aarch64-linux-gnu -mattr=+dotprod,+i8mm -global-isel -global-isel-abort=2 < %s 2>&1 | FileCheck %s --check-prefixes=CHECK,CHECK-GI
 
 ; CHECK-GI:       warning: Instruction selection used fallback path for test_udot_v5i8
 ; CHECK-GI-NEXT:  warning: Instruction selection used fallback path for test_udot_v5i8_nomla
 ; CHECK-GI-NEXT:  warning: Instruction selection used fallback path for test_sdot_v5i8
 ; CHECK-GI-NEXT:  warning: Instruction selection used fallback path for test_sdot_v5i8_double
 ; CHECK-GI-NEXT:  warning: Instruction selection used fallback path for test_sdot_v5i8_double_nomla
+; CHECK-GI-NEXT:  warning: Instruction selection used fallback path for test_usdot_v5i8
+; CHECK-GI-NEXT:  warning: Instruction selection used fallback path for test_usdot_v5i8_double
 ; CHECK-GI-NEXT:  warning: Instruction selection used fallback path for test_udot_v25i8
 ; CHECK-GI-NEXT:  warning: Instruction selection used fallback path for test_udot_v25i8_nomla
 ; CHECK-GI-NEXT:  warning: Instruction selection used fallback path for test_sdot_v25i8
 ; CHECK-GI-NEXT:  warning: Instruction selection used fallback path for test_sdot_v25i8_double
 ; CHECK-GI-NEXT:  warning: Instruction selection used fallback path for test_sdot_v25i8_double_nomla
+; CHECK-GI-NEXT:  warning: Instruction selection used fallback path for test_usdot_v25i8
+; CHECK-GI-NEXT:  warning: Instruction selection used fallback path for test_usdot_v25i8_double
 ; CHECK-GI-NEXT:  warning: Instruction selection used fallback path for test_udot_v33i8
 ; CHECK-GI-NEXT:  warning: Instruction selection used fallback path for test_udot_v33i8_nomla
 ; CHECK-GI-NEXT:  warning: Instruction selection used fallback path for test_sdot_v33i8
 ; CHECK-GI-NEXT:  warning: Instruction selection used fallback path for test_sdot_v33i8_double
 ; CHECK-GI-NEXT:  warning: Instruction selection used fallback path for test_sdot_v33i8_double_nomla
+; CHECK-GI-NEXT:  warning: Instruction selection used fallback path for test_usdot_v33i8
+; CHECK-GI-NEXT:  warning: Instruction selection used fallback path for test_usdot_v33i8_double
 
 declare i32 @llvm.vector.reduce.add.v4i32(<4 x i32>)
 declare i32 @llvm.vector.reduce.add.v5i32(<5 x i32>)
@@ -290,6 +296,128 @@ entry:
   ret i32 %x
 }
 
+define i32 @test_usdot_v4i8(ptr nocapture readonly %a, ptr nocapture readonly %b, i32 %sum) { ; zext(*%a) * sext(*%b) over 4 lanes; the SD output has no usdot (it uses smull)
+; CHECK-SD-LABEL: test_usdot_v4i8:
+; CHECK-SD:       // %bb.0: // %entry
+; CHECK-SD-NEXT:    ldr s0, [x0]
+; CHECK-SD-NEXT:    ldr s1, [x1]
+; CHECK-SD-NEXT:    ushll v0.8h, v0.8b, #0
+; CHECK-SD-NEXT:    sshll v1.8h, v1.8b, #0
+; CHECK-SD-NEXT:    smull v0.4s, v1.4h, v0.4h
+; CHECK-SD-NEXT:    addv s0, v0.4s
+; CHECK-SD-NEXT:    fmov w8, s0
+; CHECK-SD-NEXT:    add w0, w8, w2
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: test_usdot_v4i8:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    ldr w8, [x0]
+; CHECK-GI-NEXT:    ldr w9, [x1]
+; CHECK-GI-NEXT:    fmov s0, w8
+; CHECK-GI-NEXT:    fmov s2, w9
+; CHECK-GI-NEXT:    uxtb w8, w8
+; CHECK-GI-NEXT:    sxtb w9, w9
+; CHECK-GI-NEXT:    mov b1, v0.b[1]
+; CHECK-GI-NEXT:    mov b3, v0.b[2]
+; CHECK-GI-NEXT:    mov b5, v2.b[2]
+; CHECK-GI-NEXT:    mov b4, v0.b[3]
+; CHECK-GI-NEXT:    mov b0, v2.b[1]
+; CHECK-GI-NEXT:    mov b6, v2.b[3]
+; CHECK-GI-NEXT:    fmov s2, w9
+; CHECK-GI-NEXT:    fmov w10, s1
+; CHECK-GI-NEXT:    fmov w11, s3
+; CHECK-GI-NEXT:    fmov s1, w8
+; CHECK-GI-NEXT:    fmov w13, s5
+; CHECK-GI-NEXT:    fmov w8, s4
+; CHECK-GI-NEXT:    fmov w12, s0
+; CHECK-GI-NEXT:    uxtb w10, w10
+; CHECK-GI-NEXT:    uxtb w11, w11
+; CHECK-GI-NEXT:    sxtb w13, w13
+; CHECK-GI-NEXT:    uxtb w8, w8
+; CHECK-GI-NEXT:    sxtb w12, w12
+; CHECK-GI-NEXT:    mov v1.h[1], w10
+; CHECK-GI-NEXT:    fmov w10, s6
+; CHECK-GI-NEXT:    fmov s0, w11
+; CHECK-GI-NEXT:    fmov s3, w13
+; CHECK-GI-NEXT:    mov v2.h[1], w12
+; CHECK-GI-NEXT:    sxtb w10, w10
+; CHECK-GI-NEXT:    mov v0.h[1], w8
+; CHECK-GI-NEXT:    ushll v1.4s, v1.4h, #0
+; CHECK-GI-NEXT:    mov v3.h[1], w10
+; CHECK-GI-NEXT:    sshll v2.4s, v2.4h, #0
+; CHECK-GI-NEXT:    ushll v0.4s, v0.4h, #0
+; CHECK-GI-NEXT:    sshll v3.4s, v3.4h, #0
+; CHECK-GI-NEXT:    mov v1.d[1], v0.d[0]
+; CHECK-GI-NEXT:    mov v2.d[1], v3.d[0]
+; CHECK-GI-NEXT:    mul v0.4s, v2.4s, v1.4s
+; CHECK-GI-NEXT:    addv s0, v0.4s
+; CHECK-GI-NEXT:    fmov w8, s0
+; CHECK-GI-NEXT:    add w0, w8, w2
+; CHECK-GI-NEXT:    ret
+entry:
+  %0 = load <4 x i8>, ptr %a
+  %1 = zext <4 x i8> %0 to <4 x i32>
+  %2 = load <4 x i8>, ptr %b
+  %3 = sext <4 x i8> %2 to <4 x i32>
+  %4 = mul nsw <4 x i32> %3, %1
+  %5 = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %4)
+  %op.extra = add nsw i32 %5, %sum
+  ret i32 %op.extra
+}
+
+define i32 @test_usdot_v4i8_double(<4 x i8> %a, <4 x i8> %b, <4 x i8> %c, <4 x i8> %d) {
+; CHECK-SD-LABEL: test_usdot_v4i8_double:
+; CHECK-SD:       // %bb.0: // %entry
+; CHECK-SD-NEXT:    ushll v3.4s, v3.4h, #0
+; CHECK-SD-NEXT:    bic v2.4h, #255, lsl #8
+; CHECK-SD-NEXT:    ushll v1.4s, v1.4h, #0
+; CHECK-SD-NEXT:    bic v0.4h, #255, lsl #8
+; CHECK-SD-NEXT:    shl v3.4s, v3.4s, #24
+; CHECK-SD-NEXT:    ushll v2.4s, v2.4h, #0
+; CHECK-SD-NEXT:    shl v1.4s, v1.4s, #24
+; CHECK-SD-NEXT:    ushll v0.4s, v0.4h, #0
+; CHECK-SD-NEXT:    sshr v3.4s, v3.4s, #24
+; CHECK-SD-NEXT:    sshr v1.4s, v1.4s, #24
+; CHECK-SD-NEXT:    mul v2.4s, v2.4s, v3.4s
+; CHECK-SD-NEXT:    mla v2.4s, v0.4s, v1.4s
+; CHECK-SD-NEXT:    addv s0, v2.4s
+; CHECK-SD-NEXT:    fmov w0, s0
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: test_usdot_v4i8_double:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    ushll v1.4s, v1.4h, #0
+; CHECK-GI-NEXT:    ushll v3.4s, v3.4h, #0
+; CHECK-GI-NEXT:    movi v4.2d, #0x0000ff000000ff
+; CHECK-GI-NEXT:    ushll v0.4s, v0.4h, #0
+; CHECK-GI-NEXT:    ushll v2.4s, v2.4h, #0
+; CHECK-GI-NEXT:    shl v1.4s, v1.4s, #24
+; CHECK-GI-NEXT:    shl v3.4s, v3.4s, #24
+; CHECK-GI-NEXT:    and v0.16b, v0.16b, v4.16b
+; CHECK-GI-NEXT:    and v2.16b, v2.16b, v4.16b
+; CHECK-GI-NEXT:    sshr v1.4s, v1.4s, #24
+; CHECK-GI-NEXT:    sshr v3.4s, v3.4s, #24
+; CHECK-GI-NEXT:    mul v0.4s, v0.4s, v1.4s
+; CHECK-GI-NEXT:    mul v1.4s, v2.4s, v3.4s
+; CHECK-GI-NEXT:    addv s0, v0.4s
+; CHECK-GI-NEXT:    addv s1, v1.4s
+; CHECK-GI-NEXT:    fmov w8, s0
+; CHECK-GI-NEXT:    fmov w9, s1
+; CHECK-GI-NEXT:    add w0, w8, w9
+; CHECK-GI-NEXT:    ret
+entry:
+  %az = zext <4 x i8> %a to <4 x i32>
+  %bz = sext <4 x i8> %b to <4 x i32>
+  %m1 = mul nsw <4 x i32> %az, %bz ; no nuw: a negative sext operand makes the unsigned product wrap (poison)
+  %r1 = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %m1)
+  %cz = zext <4 x i8> %c to <4 x i32>
+  %dz = sext <4 x i8> %d to <4 x i32>
+  %m2 = mul nsw <4 x i32> %cz, %dz
+  %r2 = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %m2)
+  %x = add i32 %r1, %r2
+  ret i32 %x
+}
+
 define i32 @test_udot_v5i8(ptr nocapture readonly %a, ptr nocapture readonly %b, i32 %sum) {
 ; CHECK-LABEL: test_udot_v5i8:
 ; CHECK:       // %bb.0: // %entry
@@ -414,6 +542,65 @@ entry:
   ret i32 %x
 }
 
+define i32 @test_usdot_v5i8(ptr nocapture readonly %a, ptr nocapture readonly %b, i32 %sum) { ; zext(*%a) * sext(*%b) over 5 lanes; no usdot is formed (smull2/smlal) and GlobalISel falls back
+; CHECK-LABEL: test_usdot_v5i8:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldr d0, [x0]
+; CHECK-NEXT:    ldr d1, [x1]
+; CHECK-NEXT:    movi v3.2d, #0000000000000000
+; CHECK-NEXT:    ushll v0.8h, v0.8b, #0
+; CHECK-NEXT:    sshll v1.8h, v1.8b, #0
+; CHECK-NEXT:    smull2 v2.4s, v1.8h, v0.8h
+; CHECK-NEXT:    mov v3.s[0], v2.s[0]
+; CHECK-NEXT:    smlal v3.4s, v1.4h, v0.4h
+; CHECK-NEXT:    addv s0, v3.4s
+; CHECK-NEXT:    fmov w8, s0
+; CHECK-NEXT:    add w0, w8, w2
+; CHECK-NEXT:    ret
+entry:
+  %0 = load <5 x i8>, ptr %a
+  %1 = zext <5 x i8> %0 to <5 x i32>
+  %2 = load <5 x i8>, ptr %b
+  %3 = sext <5 x i8> %2 to <5 x i32>
+  %4 = mul nsw <5 x i32> %3, %1
+  %5 = call i32 @llvm.vector.reduce.add.v5i32(<5 x i32> %4)
+  %op.extra = add nsw i32 %5, %sum
+  ret i32 %op.extra
+}
+
+define i32 @test_usdot_v5i8_double(<5 x i8> %a, <5 x i8> %b, <5 x i8> %c, <5 x i8> %d) {
+; CHECK-LABEL: test_usdot_v5i8_double:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sshll v1.8h, v1.8b, #0
+; CHECK-NEXT:    ushll v0.8h, v0.8b, #0
+; CHECK-NEXT:    ushll v2.8h, v2.8b, #0
+; CHECK-NEXT:    sshll v3.8h, v3.8b, #0
+; CHECK-NEXT:    movi v5.2d, #0000000000000000
+; CHECK-NEXT:    movi v6.2d, #0000000000000000
+; CHECK-NEXT:    smull2 v4.4s, v0.8h, v1.8h
+; CHECK-NEXT:    smull2 v7.4s, v2.8h, v3.8h
+; CHECK-NEXT:    mov v6.s[0], v4.s[0]
+; CHECK-NEXT:    mov v5.s[0], v7.s[0]
+; CHECK-NEXT:    smlal v6.4s, v0.4h, v1.4h
+; CHECK-NEXT:    smlal v5.4s, v2.4h, v3.4h
+; CHECK-NEXT:    add v0.4s, v6.4s, v5.4s
+; CHECK-NEXT:    addv s0, v0.4s
+; CHECK-NEXT:    fmov w0, s0
+; CHECK-NEXT:    ret
+entry:
+  %az = zext <5 x i8> %a to <5 x i32>
+  %bz = sext <5 x i8> %b to <5 x i32>
+  %m1 = mul nsw <5 x i32> %az, %bz ; no nuw: a negative sext operand makes the unsigned product wrap (poison)
+  %r1 = call i32 @llvm.vector.reduce.add.v5i32(<5 x i32> %m1)
+  %cz = zext <5 x i8> %c to <5 x i32>
+  %dz = sext <5 x i8> %d to <5 x i32>
+  %m2 = mul nsw <5 x i32> %cz, %dz
+  %r2 = call i32 @llvm.vector.reduce.add.v5i32(<5 x i32> %m2)
+  %x = add i32 %r1, %r2
+  ret i32 %x
+}
+
+
 define i32 @test_udot_v8i8(ptr nocapture readonly %a, ptr nocapture readonly %b) {
 ; CHECK-LABEL: test_udot_v8i8:
 ; CHECK:       // %bb.0: // %entry
@@ -508,6 +695,77 @@ entry:
   ret i32 %2
 }
 
+define i32 @test_usdot_v8i8(ptr nocapture readonly %a, ptr nocapture readonly %b) { ; zext(*%a) * sext(*%b) over 8 lanes; SD selects one usdot v0.2s with the zext input (%a) first
+; CHECK-SD-LABEL: test_usdot_v8i8:
+; CHECK-SD:       // %bb.0: // %entry
+; CHECK-SD-NEXT:    movi v0.2d, #0000000000000000
+; CHECK-SD-NEXT:    ldr d1, [x0]
+; CHECK-SD-NEXT:    ldr d2, [x1]
+; CHECK-SD-NEXT:    usdot v0.2s, v1.8b, v2.8b
+; CHECK-SD-NEXT:    addp v0.2s, v0.2s, v0.2s
+; CHECK-SD-NEXT:    fmov w0, s0
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: test_usdot_v8i8:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    ldr d0, [x0]
+; CHECK-GI-NEXT:    ldr d1, [x1]
+; CHECK-GI-NEXT:    ushll v0.8h, v0.8b, #0
+; CHECK-GI-NEXT:    sshll v1.8h, v1.8b, #0
+; CHECK-GI-NEXT:    ushll2 v2.4s, v0.8h, #0
+; CHECK-GI-NEXT:    sshll2 v3.4s, v1.8h, #0
+; CHECK-GI-NEXT:    ushll v0.4s, v0.4h, #0
+; CHECK-GI-NEXT:    sshll v1.4s, v1.4h, #0
+; CHECK-GI-NEXT:    mul v2.4s, v3.4s, v2.4s
+; CHECK-GI-NEXT:    mla v2.4s, v1.4s, v0.4s
+; CHECK-GI-NEXT:    addv s0, v2.4s
+; CHECK-GI-NEXT:    fmov w0, s0
+; CHECK-GI-NEXT:    ret
+entry:
+  %0 = load <8 x i8>, ptr %a
+  %1 = zext <8 x i8> %0 to <8 x i32>
+  %2 = load <8 x i8>, ptr %b
+  %3 = sext <8 x i8> %2 to <8 x i32>
+  %4 = mul nsw <8 x i32> %3, %1
+  %5 = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %4)
+  ret i32 %5
+}
+
+define i32 @test_usdot_swapped_operands_v8i8(ptr nocapture readonly %a, ptr nocapture readonly %b) { ; sext(*%a) * zext(*%b); SD still passes the zext input (%b, v2) first to usdot
+; CHECK-SD-LABEL: test_usdot_swapped_operands_v8i8:
+; CHECK-SD:       // %bb.0: // %entry
+; CHECK-SD-NEXT:    movi v0.2d, #0000000000000000
+; CHECK-SD-NEXT:    ldr d1, [x0]
+; CHECK-SD-NEXT:    ldr d2, [x1]
+; CHECK-SD-NEXT:    usdot v0.2s, v2.8b, v1.8b
+; CHECK-SD-NEXT:    addp v0.2s, v0.2s, v0.2s
+; CHECK-SD-NEXT:    fmov w0, s0
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: test_usdot_swapped_operands_v8i8:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    ldr d0, [x0]
+; CHECK-GI-NEXT:    ldr d1, [x1]
+; CHECK-GI-NEXT:    sshll v0.8h, v0.8b, #0
+; CHECK-GI-NEXT:    ushll v1.8h, v1.8b, #0
+; CHECK-GI-NEXT:    sshll2 v2.4s, v0.8h, #0
+; CHECK-GI-NEXT:    ushll2 v3.4s, v1.8h, #0
+; CHECK-GI-NEXT:    sshll v0.4s, v0.4h, #0
+; CHECK-GI-NEXT:    ushll v1.4s, v1.4h, #0
+; CHECK-GI-NEXT:    mul v2.4s, v3.4s, v2.4s
+; CHECK-GI-NEXT:    mla v2.4s, v1.4s, v0.4s
+; CHECK-GI-NEXT:    addv s0, v2.4s
+; CHECK-GI-NEXT:    fmov w0, s0
+; CHECK-GI-NEXT:    ret
+entry:
+  %0 = load <8 x i8>, ptr %a
+  %1 = sext <8 x i8> %0 to <8 x i32>
+  %2 = load <8 x i8>, ptr %b
+  %3 = zext <8 x i8> %2 to <8 x i32>
+  %4 = mul nsw <8 x i32> %3, %1
+  %5 = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %4)
+  ret i32 %5
+}
 
 define i32 @test_udot_v16i8(ptr nocapture readonly %a, ptr nocapture readonly %b, i32 %sum) {
 ; CHECK-LABEL: test_udot_v16i8:
@@ -587,6 +845,101 @@ entry:
   ret i32 %2
 }
 
+define i32 @test_usdot_v16i8(ptr nocapture readonly %a, ptr nocapture readonly %b, i32 %sum) { ; zext(*%a) * sext(*%b) over 16 lanes; SD selects one usdot v0.4s with the zext input (%a) first
+; CHECK-SD-LABEL: test_usdot_v16i8:
+; CHECK-SD:       // %bb.0: // %entry
+; CHECK-SD-NEXT:    movi v0.2d, #0000000000000000
+; CHECK-SD-NEXT:    ldr q1, [x0]
+; CHECK-SD-NEXT:    ldr q2, [x1]
+; CHECK-SD-NEXT:    usdot v0.4s, v1.16b, v2.16b
+; CHECK-SD-NEXT:    addv s0, v0.4s
+; CHECK-SD-NEXT:    fmov w8, s0
+; CHECK-SD-NEXT:    add w0, w8, w2
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: test_usdot_v16i8:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    ldr q0, [x0]
+; CHECK-GI-NEXT:    ldr q1, [x1]
+; CHECK-GI-NEXT:    ushll v2.8h, v0.8b, #0
+; CHECK-GI-NEXT:    ushll2 v0.8h, v0.16b, #0
+; CHECK-GI-NEXT:    sshll v3.8h, v1.8b, #0
+; CHECK-GI-NEXT:    sshll2 v1.8h, v1.16b, #0
+; CHECK-GI-NEXT:    ushll2 v4.4s, v2.8h, #0
+; CHECK-GI-NEXT:    ushll2 v5.4s, v0.8h, #0
+; CHECK-GI-NEXT:    sshll2 v6.4s, v3.8h, #0
+; CHECK-GI-NEXT:    sshll2 v7.4s, v1.8h, #0
+; CHECK-GI-NEXT:    ushll v2.4s, v2.4h, #0
+; CHECK-GI-NEXT:    ushll v0.4s, v0.4h, #0
+; CHECK-GI-NEXT:    sshll v3.4s, v3.4h, #0
+; CHECK-GI-NEXT:    sshll v1.4s, v1.4h, #0
+; CHECK-GI-NEXT:    mul v4.4s, v6.4s, v4.4s
+; CHECK-GI-NEXT:    mul v5.4s, v7.4s, v5.4s
+; CHECK-GI-NEXT:    mla v4.4s, v3.4s, v2.4s
+; CHECK-GI-NEXT:    mla v5.4s, v1.4s, v0.4s
+; CHECK-GI-NEXT:    add v0.4s, v4.4s, v5.4s
+; CHECK-GI-NEXT:    addv s0, v0.4s
+; CHECK-GI-NEXT:    fmov w8, s0
+; CHECK-GI-NEXT:    add w0, w8, w2
+; CHECK-GI-NEXT:    ret
+entry:
+  %0 = load <16 x i8>, ptr %a
+  %1 = zext <16 x i8> %0 to <16 x i32>
+  %2 = load <16 x i8>, ptr %b
+  %3 = sext <16 x i8> %2 to <16 x i32>
+  %4 = mul nsw <16 x i32> %3, %1
+  %5 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %4)
+  %op.extra = add nsw i32 %5, %sum
+  ret i32 %op.extra
+}
+
+define i32 @test_usdot_swapped_operands_v16i8(ptr nocapture readonly %a, ptr nocapture readonly %b, i32 %sum) { ; sext(*%a) * zext(*%b); SD still passes the zext input (%b, v2) first to usdot
+; CHECK-SD-LABEL: test_usdot_swapped_operands_v16i8:
+; CHECK-SD:       // %bb.0: // %entry
+; CHECK-SD-NEXT:    movi v0.2d, #0000000000000000
+; CHECK-SD-NEXT:    ldr q1, [x0]
+; CHECK-SD-NEXT:    ldr q2, [x1]
+; CHECK-SD-NEXT:    usdot v0.4s, v2.16b, v1.16b
+; CHECK-SD-NEXT:    addv s0, v0.4s
+; CHECK-SD-NEXT:    fmov w8, s0
+; CHECK-SD-NEXT:    add w0, w8, w2
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: test_usdot_swapped_operands_v16i8:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    ldr q0, [x0]
+; CHECK-GI-NEXT:    ldr q1, [x1]
+; CHECK-GI-NEXT:    sshll v2.8h, v0.8b, #0
+; CHECK-GI-NEXT:    sshll2 v0.8h, v0.16b, #0
+; CHECK-GI-NEXT:    ushll v3.8h, v1.8b, #0
+; CHECK-GI-NEXT:    ushll2 v1.8h, v1.16b, #0
+; CHECK-GI-NEXT:    sshll2 v4.4s, v2.8h, #0
+; CHECK-GI-NEXT:    sshll2 v5.4s, v0.8h, #0
+; CHECK-GI-NEXT:    ushll2 v6.4s, v3.8h, #0
+; CHECK-GI-NEXT:    ushll2 v7.4s, v1.8h, #0
+; CHECK-GI-NEXT:    sshll v2.4s, v2.4h, #0
+; CHECK-GI-NEXT:    sshll v0.4s, v0.4h, #0
+; CHECK-GI-NEXT:    ushll v3.4s, v3.4h, #0
+; CHECK-GI-NEXT:    ushll v1.4s, v1.4h, #0
+; CHECK-GI-NEXT:    mul v4.4s, v6.4s, v4.4s
+; CHECK-GI-NEXT:    mul v5.4s, v7.4s, v5.4s
+; CHECK-GI-NEXT:    mla v4.4s, v3.4s, v2.4s
+; CHECK-GI-NEXT:    mla v5.4s, v1.4s, v0.4s
+; CHECK-GI-NEXT:    add v0.4s, v4.4s, v5.4s
+; CHECK-GI-NEXT:    addv s0, v0.4s
+; CHECK-GI-NEXT:    fmov w8, s0
+; CHECK-GI-NEXT:    add w0, w8, w2
+; CHECK-GI-NEXT:    ret
+entry:
+  %0 = load <16 x i8>, ptr %a
+  %1 = sext <16 x i8> %0 to <16 x i32>
+  %2 = load <16 x i8>, ptr %b
+  %3 = zext <16 x i8> %2 to <16 x i32>
+  %4 = mul nsw <16 x i32> %3, %1
+  %5 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %4)
+  %op.extra = add nsw i32 %5, %sum
+  ret i32 %op.extra
+}
 
 define i32 @test_udot_v8i8_double(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d) {
 ; CHECK-SD-LABEL: test_udot_v8i8_double:
@@ -860,19 +1213,253 @@ entry:
   ret i32 %x
 }
 
-define i32 @test_udot_v24i8(ptr nocapture readonly %a, ptr nocapture readonly %b, i32 %sum) {
-; CHECK-SD-LABEL: test_udot_v24i8:
+
+define i32 @test_usdot_v8i8_double(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d) {
+; CHECK-SD-LABEL: test_usdot_v8i8_double:
 ; CHECK-SD:       // %bb.0: // %entry
-; CHECK-SD-NEXT:    movi v0.2d, #0000000000000000
-; CHECK-SD-NEXT:    movi v1.2d, #0000000000000000
-; CHECK-SD-NEXT:    ldr q2, [x0]
-; CHECK-SD-NEXT:    ldr q3, [x1]
-; CHECK-SD-NEXT:    ldr d4, [x0, #16]
-; CHECK-SD-NEXT:    ldr d5, [x1, #16]
-; CHECK-SD-NEXT:    udot v1.2s, v5.8b, v4.8b
-; CHECK-SD-NEXT:    udot v0.4s, v3.16b, v2.16b
-; CHECK-SD-NEXT:    addp v1.2s, v1.2s, v1.2s
-; CHECK-SD-NEXT:    addv s0, v0.4s
+; CHECK-SD-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-SD-NEXT:    movi v5.2d, #0000000000000000
+; CHECK-SD-NEXT:    usdot v5.2s, v0.8b, v1.8b
+; CHECK-SD-NEXT:    usdot v4.2s, v2.8b, v3.8b
+; CHECK-SD-NEXT:    add v0.2s, v5.2s, v4.2s
+; CHECK-SD-NEXT:    addp v0.2s, v0.2s, v0.2s
+; CHECK-SD-NEXT:    fmov w0, s0
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: test_usdot_v8i8_double:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    ushll v0.8h, v0.8b, #0
+; CHECK-GI-NEXT:    sshll v1.8h, v1.8b, #0
+; CHECK-GI-NEXT:    ushll v2.8h, v2.8b, #0
+; CHECK-GI-NEXT:    sshll v3.8h, v3.8b, #0
+; CHECK-GI-NEXT:    ushll2 v4.4s, v0.8h, #0
+; CHECK-GI-NEXT:    sshll2 v5.4s, v1.8h, #0
+; CHECK-GI-NEXT:    ushll2 v6.4s, v2.8h, #0
+; CHECK-GI-NEXT:    sshll2 v7.4s, v3.8h, #0
+; CHECK-...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/120094


More information about the llvm-commits mailing list