[llvm] [AArch64][NEON] Lower fixed-width add partial reductions to dot product (PR #107078)

Paul Walker via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 5 06:38:15 PDT 2024


================
@@ -0,0 +1,209 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod < %s | FileCheck %s
+; RUN: llc -mtriple aarch64 -mattr=+neon < %s | FileCheck %s --check-prefix=CHECK-NODOTPROD
+
+define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) #0 {
+; CHECK-LABEL: udot:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    udot v0.4s, v2.16b, v1.16b
+; CHECK-NEXT:    ret
+;
+; CHECK-NODOTPROD-LABEL: udot:
+; CHECK-NODOTPROD:       // %bb.0:
+; CHECK-NODOTPROD-NEXT:    umull v3.8h, v2.8b, v1.8b
+; CHECK-NODOTPROD-NEXT:    umull2 v1.8h, v2.16b, v1.16b
+; CHECK-NODOTPROD-NEXT:    ushll v2.4s, v1.4h, #0
+; CHECK-NODOTPROD-NEXT:    uaddw v0.4s, v0.4s, v3.4h
+; CHECK-NODOTPROD-NEXT:    uaddw2 v2.4s, v2.4s, v3.8h
+; CHECK-NODOTPROD-NEXT:    uaddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NODOTPROD-NEXT:    add v0.4s, v2.4s, v0.4s
+; CHECK-NODOTPROD-NEXT:    ret
+  %u.wide = zext <16 x i8> %u to <16 x i32>
+  %s.wide = zext <16 x i8> %s to <16 x i32>
+  %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mult)
+  ret <4 x i32> %partial.reduce
+}
+
+define <2 x i32> @udot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0 {
+; CHECK-LABEL: udot_narrow:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    udot v0.2s, v2.8b, v1.8b
+; CHECK-NEXT:    ret
+;
+; CHECK-NODOTPROD-LABEL: udot_narrow:
+; CHECK-NODOTPROD:       // %bb.0:
+; CHECK-NODOTPROD-NEXT:    umull v1.8h, v2.8b, v1.8b
+; CHECK-NODOTPROD-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NODOTPROD-NEXT:    ushll v2.4s, v1.4h, #0
+; CHECK-NODOTPROD-NEXT:    ushll2 v3.4s, v1.8h, #0
+; CHECK-NODOTPROD-NEXT:    ext v4.16b, v1.16b, v1.16b, #8
+; CHECK-NODOTPROD-NEXT:    uaddw v0.4s, v0.4s, v1.4h
+; CHECK-NODOTPROD-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NODOTPROD-NEXT:    ext v2.16b, v2.16b, v2.16b, #8
+; CHECK-NODOTPROD-NEXT:    add v0.2s, v3.2s, v0.2s
+; CHECK-NODOTPROD-NEXT:    uaddw v1.4s, v2.4s, v4.4h
+; CHECK-NODOTPROD-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NODOTPROD-NEXT:    ret
+  %u.wide = zext <8 x i8> %u to <8 x i32>
+  %s.wide = zext <8 x i8> %s to <8 x i32>
+  %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v2i32.v8i32(<2 x i32> %acc, <8 x i32> %mult)
+  ret <2 x i32> %partial.reduce
+}
+
+define <4 x i32> @sdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) #0 {
+; CHECK-LABEL: sdot:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sdot v0.4s, v2.16b, v1.16b
+; CHECK-NEXT:    ret
+;
+; CHECK-NODOTPROD-LABEL: sdot:
+; CHECK-NODOTPROD:       // %bb.0:
+; CHECK-NODOTPROD-NEXT:    smull v3.8h, v2.8b, v1.8b
+; CHECK-NODOTPROD-NEXT:    smull2 v1.8h, v2.16b, v1.16b
+; CHECK-NODOTPROD-NEXT:    sshll v2.4s, v1.4h, #0
+; CHECK-NODOTPROD-NEXT:    saddw v0.4s, v0.4s, v3.4h
+; CHECK-NODOTPROD-NEXT:    saddw2 v2.4s, v2.4s, v3.8h
+; CHECK-NODOTPROD-NEXT:    saddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NODOTPROD-NEXT:    add v0.4s, v2.4s, v0.4s
+; CHECK-NODOTPROD-NEXT:    ret
+  %u.wide = sext <16 x i8> %u to <16 x i32>
+  %s.wide = sext <16 x i8> %s to <16 x i32>
+  %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mult)
+  ret <4 x i32> %partial.reduce
+}
+
+define <2 x i32> @sdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0 {
+; CHECK-LABEL: sdot_narrow:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sdot v0.2s, v2.8b, v1.8b
+; CHECK-NEXT:    ret
+;
+; CHECK-NODOTPROD-LABEL: sdot_narrow:
+; CHECK-NODOTPROD:       // %bb.0:
+; CHECK-NODOTPROD-NEXT:    smull v1.8h, v2.8b, v1.8b
+; CHECK-NODOTPROD-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NODOTPROD-NEXT:    sshll v2.4s, v1.4h, #0
+; CHECK-NODOTPROD-NEXT:    sshll2 v3.4s, v1.8h, #0
+; CHECK-NODOTPROD-NEXT:    ext v4.16b, v1.16b, v1.16b, #8
+; CHECK-NODOTPROD-NEXT:    saddw v0.4s, v0.4s, v1.4h
+; CHECK-NODOTPROD-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NODOTPROD-NEXT:    ext v2.16b, v2.16b, v2.16b, #8
+; CHECK-NODOTPROD-NEXT:    add v0.2s, v3.2s, v0.2s
+; CHECK-NODOTPROD-NEXT:    saddw v1.4s, v2.4s, v4.4h
+; CHECK-NODOTPROD-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NODOTPROD-NEXT:    ret
+  %u.wide = sext <8 x i8> %u to <8 x i32>
+  %s.wide = sext <8 x i8> %s to <8 x i32>
+  %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v2i32.v8i32(<2 x i32> %acc, <8 x i32> %mult)
+  ret <2 x i32> %partial.reduce
+}
+
+define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0 {
+; CHECK-LABEL: not_udot:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    umull v1.8h, v2.8b, v1.8b
+; CHECK-NEXT:    uaddw v0.4s, v0.4s, v1.4h
+; CHECK-NEXT:    uaddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NEXT:    ret
+;
+; CHECK-NODOTPROD-LABEL: not_udot:
+; CHECK-NODOTPROD:       // %bb.0:
+; CHECK-NODOTPROD-NEXT:    umull v1.8h, v2.8b, v1.8b
+; CHECK-NODOTPROD-NEXT:    uaddw v0.4s, v0.4s, v1.4h
+; CHECK-NODOTPROD-NEXT:    uaddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NODOTPROD-NEXT:    ret
+  %u.wide = zext <8 x i8> %u to <8 x i32>
+  %s.wide = zext <8 x i8> %s to <8 x i32>
+  %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v8i32(<4 x i32> %acc, <8 x i32> %mult)
+  ret <4 x i32> %partial.reduce
+}
+
+define <2 x i32> @not_udot_narrow(<2 x i32> %acc, <4 x i8> %u, <4 x i8> %s) #0 {
+; CHECK-LABEL: not_udot_narrow:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    bic v1.4h, #255, lsl #8
+; CHECK-NEXT:    bic v2.4h, #255, lsl #8
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NEXT:    umull v3.4s, v2.4h, v1.4h
+; CHECK-NEXT:    umlal v0.4s, v2.4h, v1.4h
+; CHECK-NEXT:    ext v1.16b, v3.16b, v3.16b, #8
+; CHECK-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NEXT:    ret
+;
+; CHECK-NODOTPROD-LABEL: not_udot_narrow:
+; CHECK-NODOTPROD:       // %bb.0:
+; CHECK-NODOTPROD-NEXT:    bic v1.4h, #255, lsl #8
+; CHECK-NODOTPROD-NEXT:    bic v2.4h, #255, lsl #8
+; CHECK-NODOTPROD-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NODOTPROD-NEXT:    umull v3.4s, v2.4h, v1.4h
+; CHECK-NODOTPROD-NEXT:    umlal v0.4s, v2.4h, v1.4h
+; CHECK-NODOTPROD-NEXT:    ext v1.16b, v3.16b, v3.16b, #8
+; CHECK-NODOTPROD-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NODOTPROD-NEXT:    ret
+  %u.wide = zext <4 x i8> %u to <4 x i32>
+  %s.wide = zext <4 x i8> %s to <4 x i32>
+  %mult = mul nuw nsw <4 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v2i32.v4i32(<2 x i32> %acc, <4 x i32> %mult)
+  ret <2 x i32> %partial.reduce
+}
+
+define <4 x i32> @not_sdot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0 {
----------------
paulwalker-arm wrote:

Restricting my comment to `shouldExpandPartialReductionIntrinsic` was an error. What I meant is that the only difference between `not_udot` vs `not_sdot`, and between `not_udot_narrow` vs `not_sdot_narrow`, is the type of extension, so there is nothing in `not_sdot` and `not_sdot_narrow` that is not already covered by `not_udot` and `not_udot_narrow`.
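
For illustration (my sketch, not part of the patch text): the signed variants would be line-for-line identical to the unsigned ones apart from the extension instruction. Assuming `not_sdot` mirrors `not_udot` with `sext` in place of `zext`, it would look like this (the intrinsic mangling below is my assumption, following the same result-type/input-type scheme as the other calls in the file):

```llvm
; Hypothetical sketch: identical to @not_udot except zext -> sext.
; #0 refers to the attribute group defined elsewhere in the test file.
define <4 x i32> @not_sdot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0 {
  %u.wide = sext <8 x i8> %u to <8 x i32>   ; the only change vs. not_udot
  %s.wide = sext <8 x i8> %s to <8 x i32>   ; the only change vs. not_udot
  %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v8i32(<4 x i32> %acc, <8 x i32> %mult)
  ret <4 x i32> %partial.reduce
}
```

Since the extension kind is the only delta, these negative tests add no coverage beyond their unsigned counterparts.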

https://github.com/llvm/llvm-project/pull/107078

