[llvm] [AArch64] Lower partial add reduction to udot or svdot (PR #101010)

Sam Tebbs via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 29 03:41:02 PDT 2024


================
@@ -0,0 +1,176 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s
+
+define <vscale x 4 x i32> @dotp(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-LABEL: dotp:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov z2.s, #0 // =0x0
+; CHECK-NEXT:    udot z2.s, z0.b, z1.b
+; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @dotp_wide(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: dotp_wide:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov z2.d, #0 // =0x0
+; CHECK-NEXT:    udot z2.d, z0.h, z1.h
+; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+  %b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i64>
+  %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> zeroinitializer, <vscale x 8 x i64> %mult)
+  ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i32> @dotp_sext(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-LABEL: dotp_sext:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov z2.s, #0 // =0x0
+; CHECK-NEXT:    sdot z2.s, z0.b, z1.b
+; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @dotp_wide_sext(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: dotp_wide_sext:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov z2.d, #0 // =0x0
+; CHECK-NEXT:    sdot z2.d, z0.h, z1.h
+; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+  %b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i64>
+  %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> zeroinitializer, <vscale x 8 x i64> %mult)
+  ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i64> @dotp_8to64(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-LABEL: dotp_8to64:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov z2.s, #0 // =0x0
+; CHECK-NEXT:    udot z2.s, z0.b, z1.b
+; CHECK-NEXT:    uunpklo z0.d, z2.s
+; CHECK-NEXT:    uunpkhi z1.d, z2.s
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+  %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i64>
+  %mult = mul nuw nsw <vscale x 16 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(
+  <vscale x 4 x i64> zeroinitializer, <vscale x 16 x i64> %mult)
+  ret <vscale x 4 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i64> @dotp_sext_8to64(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-LABEL: dotp_sext_8to64:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov z2.s, #0 // =0x0
+; CHECK-NEXT:    sdot z2.s, z0.b, z1.b
+; CHECK-NEXT:    sunpklo z0.d, z2.s
+; CHECK-NEXT:    sunpkhi z1.d, z2.s
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+  %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i64>
+  %mult = mul nuw nsw <vscale x 16 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(
+  <vscale x 4 x i64> zeroinitializer, <vscale x 16 x i64> %mult)
+  ret <vscale x 4 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i64> @dotp_8to64_accumulator(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 4 x i64> %acc) {
+; CHECK-LABEL: dotp_8to64_accumulator:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov z4.s, #0 // =0x0
+; CHECK-NEXT:    udot z4.s, z0.b, z1.b
+; CHECK-NEXT:    uunpklo z0.d, z4.s
+; CHECK-NEXT:    uunpkhi z1.d, z4.s
+; CHECK-NEXT:    add z0.d, z2.d, z0.d
+; CHECK-NEXT:    add z1.d, z3.d, z1.d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+  %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i64>
+  %mult = mul nuw nsw <vscale x 16 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(
+  <vscale x 4 x i64> %acc, <vscale x 16 x i64> %mult)
+  ret <vscale x 4 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i64> @dotp_sext_8to64_accumulator(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 4 x i64> %acc) {
+; CHECK-LABEL: dotp_sext_8to64_accumulator:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov z4.s, #0 // =0x0
+; CHECK-NEXT:    sdot z4.s, z0.b, z1.b
+; CHECK-NEXT:    sunpklo z0.d, z4.s
+; CHECK-NEXT:    sunpkhi z1.d, z4.s
+; CHECK-NEXT:    add z0.d, z2.d, z0.d
+; CHECK-NEXT:    add z1.d, z3.d, z1.d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+  %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i64>
+  %mult = mul nuw nsw <vscale x 16 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(
+  <vscale x 4 x i64> %acc, <vscale x 16 x i64> %mult)
----------------
SamTebbs33 wrote:

Ah, that would be a mistake I made when copying the 8-to-64 tests. I've removed them as part of the commit that removes the nxv4i64 handling.

https://github.com/llvm/llvm-project/pull/101010
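
For context, the tests above exercise llvm.experimental.vector.partial.reduce.add, which adds every element of its wide second operand into the lanes of its narrower accumulator operand; roughly speaking, the grouping of input lanes onto accumulator lanes is left unspecified, which is what lets a target select an instruction such as udot. As an illustration only (the function name, the scalar form, and the fixed length of 16 below are assumptions made for the example, not part of the patch), the first test computes the equivalent of:

#include <stdint.h>

/* Rough scalar model of the @dotp test for one 128-bit register's worth of
   data: 16 u8 lanes are zero-extended, multiplied pairwise, and partially
   reduced into 4 u32 accumulator lanes -- the same computation a single
   "udot z2.s, z0.b, z1.b" performs (sdot covers the sign-extended tests). */
void dotp_model(const uint8_t a[16], const uint8_t b[16], uint32_t acc[4]) {
  for (int lane = 0; lane < 4; ++lane)
    for (int j = 0; j < 4; ++j)
      acc[lane] += (uint32_t)a[4 * lane + j] * (uint32_t)b[4 * lane + j];
}

For scalable vectors the same per-lane accumulation simply applies across the whole register, which is why the generated code only needs to zero an accumulator register and issue the dot instruction, as the CHECK lines above show.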

