[llvm] [AArch64][SVE] Add dot product codegen for partial reductions with no binary operation on input (PR #120207)

James Chesterman via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 20 07:06:35 PST 2024


https://github.com/JamesChesterman updated https://github.com/llvm/llvm-project/pull/120207

From 0619820e150b67483d9c5ae2b8a2767551bceab5 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Tue, 17 Dec 2024 10:00:36 +0000
Subject: [PATCH 1/2] [AArch64][SVE] Add dot product codegen for partial
 reductions with no binary operation on input

Add codegen for partial reductions where the input type has four times
as many elements as the output type and the input to the partial
reduction has no binary operation performed on it.
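
As a rough, informal illustration (not part of the patch): with no multiply on
the extended input, the partial reduction is the same as a dot product whose
second operand is an all-ones vector, which is why the lowering materialises a
splat of 1 (movi #1) and emits udot/sdot in the tests below. The plain C++
sketch below checks this for the fixed-width v16i8 -> v4i32 unsigned case; the
helper names are invented for the example, and it uses one valid grouping of
four adjacent input lanes per output lane (the intrinsic itself only
guarantees the overall sum).

#include <array>
#include <cassert>
#include <cstdint>

// Reference for llvm.experimental.vector.partial.reduce.add.v4i32.v16i32 on a
// zero-extended <16 x i8> input: each i32 lane accumulates four adjacent
// zero-extended bytes (one valid lane grouping).
static std::array<uint32_t, 4> partialReduceAddRef(std::array<uint32_t, 4> Acc,
                                                   const std::array<uint8_t, 16> &A) {
  for (int I = 0; I < 16; ++I)
    Acc[I / 4] += A[I];
  return Acc;
}

// Reference for "udot v0.4s, v1.16b, v2.16b" with v2 = splat(1): each i32 lane
// adds the dot product of four adjacent bytes with the constant 1.
static std::array<uint32_t, 4> udotSplatOneRef(std::array<uint32_t, 4> Acc,
                                               const std::array<uint8_t, 16> &A) {
  for (int I = 0; I < 16; ++I)
    Acc[I / 4] += static_cast<uint32_t>(A[I]) * 1u;
  return Acc;
}

int main() {
  std::array<uint32_t, 4> Acc = {1, 2, 3, 4};
  std::array<uint8_t, 16> A{};
  for (int I = 0; I < 16; ++I)
    A[I] = static_cast<uint8_t>(7 * I + 3);
  // Multiplying by an all-ones vector is the identity, so the two agree.
  assert(partialReduceAddRef(Acc, A) == udotSplatOneRef(Acc, A));
  return 0;
}
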
---
 .../neon-partial-reduce-dot-product.ll        | 160 ++++++++++++++++++
 .../AArch64/sve-partial-reduce-dot-product.ll |  78 +++++++++
 2 files changed, 238 insertions(+)

diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index c1b9a4c9dbb797..c0987582efc261 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -367,6 +367,166 @@ entry:
   ret <4 x i64> %partial.reduce
 }
 
+define <4 x i32> @udot_no_bin_op(<4 x i32> %acc, <16 x i8> %a){
+; CHECK-DOT-LABEL: udot_no_bin_op:
+; CHECK-DOT:       // %bb.0:
+; CHECK-DOT-NEXT:    movi v2.16b, #1
+; CHECK-DOT-NEXT:    udot v0.4s, v1.16b, v2.16b
+; CHECK-DOT-NEXT:    ret
+;
+; CHECK-NODOT-LABEL: udot_no_bin_op:
+; CHECK-NODOT:       // %bb.0:
+; CHECK-NODOT-NEXT:    ushll v2.8h, v1.8b, #0
+; CHECK-NODOT-NEXT:    ushll2 v1.8h, v1.16b, #0
+; CHECK-NODOT-NEXT:    ushll v3.4s, v1.4h, #0
+; CHECK-NODOT-NEXT:    uaddw v0.4s, v0.4s, v2.4h
+; CHECK-NODOT-NEXT:    uaddw2 v2.4s, v3.4s, v2.8h
+; CHECK-NODOT-NEXT:    uaddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NODOT-NEXT:    add v0.4s, v2.4s, v0.4s
+; CHECK-NODOT-NEXT:    ret
+  %a.wide = zext <16 x i8> %a to <16 x i32>
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %a.wide)
+  ret <4 x i32> %partial.reduce
+}
+
+define <4 x i32> @sdot_no_bin_op(<4 x i32> %acc, <16 x i8> %a){
+; CHECK-DOT-LABEL: sdot_no_bin_op:
+; CHECK-DOT:       // %bb.0:
+; CHECK-DOT-NEXT:    movi v2.16b, #1
+; CHECK-DOT-NEXT:    sdot v0.4s, v1.16b, v2.16b
+; CHECK-DOT-NEXT:    ret
+;
+; CHECK-NODOT-LABEL: sdot_no_bin_op:
+; CHECK-NODOT:       // %bb.0:
+; CHECK-NODOT-NEXT:    sshll v2.8h, v1.8b, #0
+; CHECK-NODOT-NEXT:    sshll2 v1.8h, v1.16b, #0
+; CHECK-NODOT-NEXT:    sshll v3.4s, v1.4h, #0
+; CHECK-NODOT-NEXT:    saddw v0.4s, v0.4s, v2.4h
+; CHECK-NODOT-NEXT:    saddw2 v2.4s, v3.4s, v2.8h
+; CHECK-NODOT-NEXT:    saddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NODOT-NEXT:    add v0.4s, v2.4s, v0.4s
+; CHECK-NODOT-NEXT:    ret
+  %a.wide = sext <16 x i8> %a to <16 x i32>
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %a.wide)
+  ret <4 x i32> %partial.reduce
+}
+
+define <2 x i32> @udot_no_bin_op_narrow(<2 x i32> %acc, <8 x i8> %a){
+; CHECK-DOT-LABEL: udot_no_bin_op_narrow:
+; CHECK-DOT:       // %bb.0:
+; CHECK-DOT-NEXT:    movi v2.8b, #1
+; CHECK-DOT-NEXT:    udot v0.2s, v1.8b, v2.8b
+; CHECK-DOT-NEXT:    ret
+;
+; CHECK-NODOT-LABEL: udot_no_bin_op_narrow:
+; CHECK-NODOT:       // %bb.0:
+; CHECK-NODOT-NEXT:    ushll v1.8h, v1.8b, #0
+; CHECK-NODOT-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NODOT-NEXT:    ushll v2.4s, v1.4h, #0
+; CHECK-NODOT-NEXT:    ushll2 v3.4s, v1.8h, #0
+; CHECK-NODOT-NEXT:    ext v4.16b, v1.16b, v1.16b, #8
+; CHECK-NODOT-NEXT:    uaddw v0.4s, v0.4s, v1.4h
+; CHECK-NODOT-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NODOT-NEXT:    ext v2.16b, v2.16b, v2.16b, #8
+; CHECK-NODOT-NEXT:    add v0.2s, v3.2s, v0.2s
+; CHECK-NODOT-NEXT:    uaddw v1.4s, v2.4s, v4.4h
+; CHECK-NODOT-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NODOT-NEXT:    ret
+  %a.wide = zext <8 x i8> %a to <8 x i32>
+  %partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v2i32.v8i32(<2 x i32> %acc, <8 x i32> %a.wide)
+  ret <2 x i32> %partial.reduce
+}
+
+define <2 x i32> @sdot_no_bin_op_narrow(<2 x i32> %acc, <8 x i8> %a){
+; CHECK-DOT-LABEL: sdot_no_bin_op_narrow:
+; CHECK-DOT:       // %bb.0:
+; CHECK-DOT-NEXT:    movi v2.8b, #1
+; CHECK-DOT-NEXT:    sdot v0.2s, v1.8b, v2.8b
+; CHECK-DOT-NEXT:    ret
+;
+; CHECK-NODOT-LABEL: sdot_no_bin_op_narrow:
+; CHECK-NODOT:       // %bb.0:
+; CHECK-NODOT-NEXT:    sshll v1.8h, v1.8b, #0
+; CHECK-NODOT-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NODOT-NEXT:    sshll v2.4s, v1.4h, #0
+; CHECK-NODOT-NEXT:    sshll2 v3.4s, v1.8h, #0
+; CHECK-NODOT-NEXT:    ext v4.16b, v1.16b, v1.16b, #8
+; CHECK-NODOT-NEXT:    saddw v0.4s, v0.4s, v1.4h
+; CHECK-NODOT-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NODOT-NEXT:    ext v2.16b, v2.16b, v2.16b, #8
+; CHECK-NODOT-NEXT:    add v0.2s, v3.2s, v0.2s
+; CHECK-NODOT-NEXT:    saddw v1.4s, v2.4s, v4.4h
+; CHECK-NODOT-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NODOT-NEXT:    ret
+  %a.wide = sext <8 x i8> %a to <8 x i32>
+  %partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v2i32.v8i32(<2 x i32> %acc, <8 x i32> %a.wide)
+  ret <2 x i32> %partial.reduce
+}
+
+define <4 x i64> @udot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){
+; CHECK-DOT-LABEL: udot_no_bin_op_8to64:
+; CHECK-DOT:       // %bb.0:
+; CHECK-DOT-NEXT:    movi v3.16b, #1
+; CHECK-DOT-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-DOT-NEXT:    udot v4.4s, v2.16b, v3.16b
+; CHECK-DOT-NEXT:    saddw2 v1.2d, v1.2d, v4.4s
+; CHECK-DOT-NEXT:    saddw v0.2d, v0.2d, v4.2s
+; CHECK-DOT-NEXT:    ret
+;
+; CHECK-NODOT-LABEL: udot_no_bin_op_8to64:
+; CHECK-NODOT:       // %bb.0:
+; CHECK-NODOT-NEXT:    ushll v3.8h, v2.8b, #0
+; CHECK-NODOT-NEXT:    ushll2 v2.8h, v2.16b, #0
+; CHECK-NODOT-NEXT:    ushll v4.4s, v3.4h, #0
+; CHECK-NODOT-NEXT:    ushll v5.4s, v2.4h, #0
+; CHECK-NODOT-NEXT:    ushll2 v3.4s, v3.8h, #0
+; CHECK-NODOT-NEXT:    ushll2 v2.4s, v2.8h, #0
+; CHECK-NODOT-NEXT:    uaddw2 v1.2d, v1.2d, v4.4s
+; CHECK-NODOT-NEXT:    uaddw v0.2d, v0.2d, v4.2s
+; CHECK-NODOT-NEXT:    uaddl2 v4.2d, v3.4s, v5.4s
+; CHECK-NODOT-NEXT:    uaddl v3.2d, v3.2s, v5.2s
+; CHECK-NODOT-NEXT:    uaddw2 v1.2d, v1.2d, v2.4s
+; CHECK-NODOT-NEXT:    uaddw v0.2d, v0.2d, v2.2s
+; CHECK-NODOT-NEXT:    add v1.2d, v4.2d, v1.2d
+; CHECK-NODOT-NEXT:    add v0.2d, v3.2d, v0.2d
+; CHECK-NODOT-NEXT:    ret
+  %a.wide = zext <16 x i8> %a to <16 x i64>
+  %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(<4 x i64> %acc, <16 x i64> %a.wide)
+  ret <4 x i64> %partial.reduce
+}
+
+define <4 x i64> @sdot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){
+; CHECK-DOT-LABEL: sdot_no_bin_op_8to64:
+; CHECK-DOT:       // %bb.0:
+; CHECK-DOT-NEXT:    movi v3.16b, #1
+; CHECK-DOT-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-DOT-NEXT:    sdot v4.4s, v2.16b, v3.16b
+; CHECK-DOT-NEXT:    saddw2 v1.2d, v1.2d, v4.4s
+; CHECK-DOT-NEXT:    saddw v0.2d, v0.2d, v4.2s
+; CHECK-DOT-NEXT:    ret
+;
+; CHECK-NODOT-LABEL: sdot_no_bin_op_8to64:
+; CHECK-NODOT:       // %bb.0:
+; CHECK-NODOT-NEXT:    sshll v3.8h, v2.8b, #0
+; CHECK-NODOT-NEXT:    sshll2 v2.8h, v2.16b, #0
+; CHECK-NODOT-NEXT:    sshll v4.4s, v3.4h, #0
+; CHECK-NODOT-NEXT:    sshll v5.4s, v2.4h, #0
+; CHECK-NODOT-NEXT:    sshll2 v3.4s, v3.8h, #0
+; CHECK-NODOT-NEXT:    sshll2 v2.4s, v2.8h, #0
+; CHECK-NODOT-NEXT:    saddw2 v1.2d, v1.2d, v4.4s
+; CHECK-NODOT-NEXT:    saddw v0.2d, v0.2d, v4.2s
+; CHECK-NODOT-NEXT:    saddl2 v4.2d, v3.4s, v5.4s
+; CHECK-NODOT-NEXT:    saddl v3.2d, v3.2s, v5.2s
+; CHECK-NODOT-NEXT:    saddw2 v1.2d, v1.2d, v2.4s
+; CHECK-NODOT-NEXT:    saddw v0.2d, v0.2d, v2.2s
+; CHECK-NODOT-NEXT:    add v1.2d, v4.2d, v1.2d
+; CHECK-NODOT-NEXT:    add v0.2d, v3.2d, v0.2d
+; CHECK-NODOT-NEXT:    ret
+  %a.wide = sext <16 x i8> %a to <16 x i64>
+  %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(<4 x i64> %acc, <16 x i64> %a.wide)
+  ret <4 x i64> %partial.reduce
+}
+
 define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
 ; CHECK-LABEL: not_udot:
 ; CHECK:       // %bb.0:
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 66d6e0388bbf94..56cd2c9d62b04f 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -316,6 +316,84 @@ entry:
   ret <vscale x 4 x i64> %partial.reduce
 }
 
+define <vscale x 4 x i32> @udot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a){
+; CHECK-LABEL: udot_no_bin_op:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z2.b, #1 // =0x1
+; CHECK-NEXT:    udot z0.s, z1.b, z2.b
+; CHECK-NEXT:    ret
+  %a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @sdot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a){
+; CHECK-LABEL: sdot_no_bin_op:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z2.b, #1 // =0x1
+; CHECK-NEXT:    sdot z0.s, z1.b, z2.b
+; CHECK-NEXT:    ret
+  %a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @udot_no_bin_op_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b){
+; CHECK-LABEL: udot_no_bin_op_wide:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov z2.h, #1 // =0x1
+; CHECK-NEXT:    udot z0.d, z1.h, z2.h
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %a.wide)
+  ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 2 x i64> @sdot_no_bin_op_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b){
+; CHECK-LABEL: sdot_no_bin_op_wide:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov z2.h, #1 // =0x1
+; CHECK-NEXT:    sdot z0.d, z1.h, z2.h
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %a.wide)
+  ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i64> @udot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a){
+; CHECK-LABEL: udot_no_bin_op_8to64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z3.b, #1 // =0x1
+; CHECK-NEXT:    mov z4.s, #0 // =0x0
+; CHECK-NEXT:    udot z4.s, z2.b, z3.b
+; CHECK-NEXT:    sunpklo z2.d, z4.s
+; CHECK-NEXT:    sunpkhi z3.d, z4.s
+; CHECK-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEXT:    add z1.d, z1.d, z3.d
+; CHECK-NEXT:    ret
+  %a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+  %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64> %acc, <vscale x 16 x i64> %a.ext)
+  ret <vscale x 4 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i64> @sdot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a){
+; CHECK-LABEL: sdot_no_bin_op_8to64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z3.b, #1 // =0x1
+; CHECK-NEXT:    mov z4.s, #0 // =0x0
+; CHECK-NEXT:    sdot z4.s, z2.b, z3.b
+; CHECK-NEXT:    sunpklo z2.d, z4.s
+; CHECK-NEXT:    sunpkhi z3.d, z4.s
+; CHECK-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEXT:    add z1.d, z1.d, z3.d
+; CHECK-NEXT:    ret
+  %a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+  %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64> %acc, <vscale x 16 x i64> %a.ext)
+  ret <vscale x 4 x i64> %partial.reduce
+}
+
 define <vscale x 4 x i32> @not_udot(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
 ; CHECK-LABEL: not_udot:
 ; CHECK:       // %bb.0: // %entry

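Aside (not part of the patch): for the v16i8 -> v4i64 "*_8to64" tests above
there is no i8-to-i64 dot product form, so the lowering goes through an i32
dot product against splat(1) starting from a zeroed accumulator and then
widens the i32 partial sums into the i64 accumulator, as the comment in the
lowering code describes. The standalone C++ sketch below (invented helper
names, one valid lane grouping of four adjacent bytes per output lane) checks
that the two-stage form computes the same result; the i32 intermediate cannot
overflow because each lane only sums four byte-sized values.

#include <array>
#include <cassert>
#include <cstdint>

// Direct semantics of the partial reduction on the zero-extended input.
static std::array<uint64_t, 4> partialReduceAdd64Ref(std::array<uint64_t, 4> Acc,
                                                     const std::array<uint8_t, 16> &A) {
  for (int I = 0; I < 16; ++I)
    Acc[I / 4] += A[I];
  return Acc;
}

// Two-stage form: i32 dot product against splat(1) into a zeroed accumulator,
// then widen the i32 partial sums and add them to the real i64 accumulator.
static std::array<uint64_t, 4> udot8to64Ref(std::array<uint64_t, 4> Acc,
                                            const std::array<uint8_t, 16> &A) {
  std::array<uint32_t, 4> Dot = {0, 0, 0, 0}; // zeroed i32 accumulator
  for (int I = 0; I < 16; ++I)
    Dot[I / 4] += A[I];                       // i32 dot product with splat(1)
  for (int I = 0; I < 4; ++I)
    Acc[I] += Dot[I];                         // widen to i64 and accumulate
  return Acc;
}

int main() {
  std::array<uint64_t, 4> Acc = {10, 20, 30, 40};
  std::array<uint8_t, 16> A{};
  for (int I = 0; I < 16; ++I)
    A[I] = static_cast<uint8_t>(255 - I);
  assert(udot8to64Ref(Acc, A) == partialReduceAdd64Ref(Acc, A));
  return 0;
}
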
From cc4fb09dfc6859df1681f1260fcd52c3f51e9e4b Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Fri, 20 Dec 2024 15:04:50 +0000
Subject: [PATCH 2/2] Address review comments. Rebase on the NFC patch that
 renamed the variables.
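
Aside (not part of the patch): the restructured code below keeps the existing
mixed-extension path that targets USDOT and, per its comment, swaps the
operands when the signed operand is not already last. The standalone C++
sketch below (invented helper name usdotRef, one valid grouping of four
adjacent lanes per i32 lane) checks that swapping the operands of the
per-element zext * sext products does not change the accumulated result,
which is what makes that swap safe.

#include <array>
#include <cassert>
#include <cstdint>

// Reference for the mixed-extension case: each i32 lane accumulates the dot
// product of four unsigned bytes from the first operand with four signed
// bytes from the second operand.
static std::array<int32_t, 4> usdotRef(std::array<int32_t, 4> Acc,
                                       const std::array<uint8_t, 16> &U,
                                       const std::array<int8_t, 16> &S) {
  for (int I = 0; I < 16; ++I)
    Acc[I / 4] += static_cast<int32_t>(U[I]) * static_cast<int32_t>(S[I]);
  return Acc;
}

int main() {
  std::array<int32_t, 4> Acc = {5, 6, 7, 8};
  std::array<uint8_t, 16> U{};
  std::array<int8_t, 16> S{};
  for (int I = 0; I < 16; ++I) {
    U[I] = static_cast<uint8_t>(17 * I + 3);
    S[I] = static_cast<int8_t>(11 * I - 80);
  }
  // IR order sext(s) * zext(u) versus the operand order zext(u) * sext(s):
  std::array<int32_t, 4> SwappedFirst = Acc;
  for (int I = 0; I < 16; ++I)
    SwappedFirst[I / 4] += static_cast<int32_t>(S[I]) * static_cast<int32_t>(U[I]);
  assert(usdotRef(Acc, U, S) == SwappedFirst);
  return 0;
}
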

---
 .../Target/AArch64/AArch64ISelLowering.cpp    |  54 ++++----
 .../neon-partial-reduce-dot-product.ll        |  88 ++++++++++++
 .../AArch64/sve-partial-reduce-dot-product.ll | 130 ++++++++++++++++++
 3 files changed, 249 insertions(+), 23 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a27c030237c878..33896eb3567083 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21953,36 +21953,46 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
   SDLoc DL(N);
 
   SDValue Op2 = N->getOperand(2);
-  if (Op2->getOpcode() != ISD::MUL ||
-      !ISD::isExtOpcode(Op2->getOperand(0)->getOpcode()) ||
-      !ISD::isExtOpcode(Op2->getOperand(1)->getOpcode()))
-    return SDValue();
+  unsigned Op2Opcode = Op2->getOpcode();
+  SDValue MulOpLHS, MulOpRHS;
+  bool MulOpLHSIsSigned, MulOpRHSIsSigned;
+  if (ISD::isExtOpcode(Op2Opcode)) {
+    MulOpLHSIsSigned = MulOpRHSIsSigned = (Op2Opcode == ISD::SIGN_EXTEND);
+    MulOpLHS = Op2->getOperand(0);
+    MulOpRHS = DAG.getConstant(1, DL, MulOpLHS.getValueType());
+  } else if (Op2Opcode == ISD::MUL) {
+    SDValue ExtMulOpLHS = Op2->getOperand(0);
+    SDValue ExtMulOpRHS = Op2->getOperand(1);
+
+    unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
+    unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
+    if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
+        !ISD::isExtOpcode(ExtMulOpRHSOpcode))
+      return SDValue();
 
-  SDValue Acc = N->getOperand(1);
-  SDValue Mul = N->getOperand(2);
-  SDValue ExtMulOpLHS = Mul->getOperand(0);
-  SDValue ExtMulOpRHS = Mul->getOperand(1);
+    MulOpLHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
+    MulOpRHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
 
-  SDValue MulOpLHS = ExtMulOpLHS->getOperand(0);
-  SDValue MulOpRHS = ExtMulOpRHS->getOperand(0);
-  if (MulOpLHS.getValueType() != MulOpRHS.getValueType())
+    MulOpLHS = ExtMulOpLHS->getOperand(0);
+    MulOpRHS = ExtMulOpRHS->getOperand(0);
+  } else
     return SDValue();
 
+  SDValue Acc = N->getOperand(1);
   EVT ReducedVT = N->getValueType(0);
   EVT MulSrcVT = MulOpLHS.getValueType();
 
   // Dot products operate on chunks of four elements so there must be four times
   // as many elements in the wide type
-  if (!(ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
-      !(ReducedVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
-      !(ReducedVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
-      !(ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
-      !(ReducedVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
-      !(ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
+  if ((!(ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
+       !(ReducedVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
+       !(ReducedVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
+       !(ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
+       !(ReducedVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
+       !(ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8)) ||
+      (MulOpLHS.getValueType() != MulOpRHS.getValueType()))
     return SDValue();
 
-  bool MulOpLHSIsSigned = ExtMulOpLHS->getOpcode() == ISD::SIGN_EXTEND;
-  bool MulOpRHSIsSigned = ExtMulOpRHS->getOpcode() == ISD::SIGN_EXTEND;
   // If the extensions are mixed, we should lower it to a usdot instead
   unsigned Opcode = 0;
   if (MulOpLHSIsSigned != MulOpRHSIsSigned) {
@@ -21998,10 +22008,8 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
     // USDOT expects the signed operand to be last
     if (!MulOpRHSIsSigned)
       std::swap(MulOpLHS, MulOpRHS);
-  } else if (MulOpLHSIsSigned)
-    Opcode = AArch64ISD::SDOT;
-  else
-    Opcode = AArch64ISD::UDOT;
+  } else
+    Opcode = MulOpLHSIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
 
   // Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
   // product followed by a zero / sign extension
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index c0987582efc261..9ece9edb843439 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -558,3 +558,91 @@ define <2 x i32> @not_udot_narrow(<2 x i32> %acc, <4 x i8> %u, <4 x i8> %s) {
   %partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<2 x i32> %acc, <4 x i32> %mult)
   ret <2 x i32> %partial.reduce
 }
+
+define <2 x i64> @udot_different_types(<2 x i64> %acc, <8 x i16> %a, <8 x i8> %b){
+; CHECK-LABEL: udot_different_types:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ushll v2.8h, v2.8b, #0
+; CHECK-NEXT:    ushll v3.4s, v1.4h, #0
+; CHECK-NEXT:    ushll2 v1.4s, v1.8h, #0
+; CHECK-NEXT:    ushll v4.4s, v2.4h, #0
+; CHECK-NEXT:    ushll2 v2.4s, v2.8h, #0
+; CHECK-NEXT:    umull v5.2d, v1.2s, v2.2s
+; CHECK-NEXT:    umlal v0.2d, v3.2s, v4.2s
+; CHECK-NEXT:    umlal2 v0.2d, v1.4s, v2.4s
+; CHECK-NEXT:    umlal2 v5.2d, v3.4s, v4.4s
+; CHECK-NEXT:    add v0.2d, v5.2d, v0.2d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <8 x i16> %a to <8 x i64>
+  %b.wide = zext <8 x i8> %b to <8 x i64>
+  %mult = mul nuw nsw <8 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add.v2i64.v8i64(<2 x i64> %acc, <8 x i64> %mult)
+  ret <2 x i64> %partial.reduce
+}
+
+define <2 x i64> @sdot_different_types(<2 x i64> %acc, <8 x i16> %a, <8 x i8> %b){
+; CHECK-LABEL: sdot_different_types:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sshll v2.8h, v2.8b, #0
+; CHECK-NEXT:    sshll v3.4s, v1.4h, #0
+; CHECK-NEXT:    sshll2 v1.4s, v1.8h, #0
+; CHECK-NEXT:    sshll v4.4s, v2.4h, #0
+; CHECK-NEXT:    sshll2 v2.4s, v2.8h, #0
+; CHECK-NEXT:    smull v5.2d, v1.2s, v2.2s
+; CHECK-NEXT:    smlal v0.2d, v3.2s, v4.2s
+; CHECK-NEXT:    smlal2 v0.2d, v1.4s, v2.4s
+; CHECK-NEXT:    smlal2 v5.2d, v3.4s, v4.4s
+; CHECK-NEXT:    add v0.2d, v5.2d, v0.2d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = sext <8 x i16> %a to <8 x i64>
+  %b.wide = sext <8 x i8> %b to <8 x i64>
+  %mult = mul nuw nsw <8 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add.v2i64.v8i64(<2 x i64> %acc, <8 x i64> %mult)
+  ret <2 x i64> %partial.reduce
+}
+
+define <2 x i64> @usdot_different_types(<2 x i64> %acc, <8 x i16> %a, <8 x i8> %b){
+; CHECK-LABEL: usdot_different_types:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sshll v2.8h, v2.8b, #0
+; CHECK-NEXT:    ushll v3.4s, v1.4h, #0
+; CHECK-NEXT:    ushll2 v1.4s, v1.8h, #0
+; CHECK-NEXT:    sshll v4.4s, v2.4h, #0
+; CHECK-NEXT:    sshll2 v2.4s, v2.8h, #0
+; CHECK-NEXT:    smull v5.2d, v1.2s, v2.2s
+; CHECK-NEXT:    smlal v0.2d, v3.2s, v4.2s
+; CHECK-NEXT:    smlal2 v0.2d, v1.4s, v2.4s
+; CHECK-NEXT:    smlal2 v5.2d, v3.4s, v4.4s
+; CHECK-NEXT:    add v0.2d, v5.2d, v0.2d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <8 x i16> %a to <8 x i64>
+  %b.wide = sext <8 x i8> %b to <8 x i64>
+  %mult = mul nuw nsw <8 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add.v2i64.v8i64(<2 x i64> %acc, <8 x i64> %mult)
+  ret <2 x i64> %partial.reduce
+}
+
+define <2 x i64> @sudot_different_types(<2 x i64> %acc, <8 x i16> %a, <8 x i8> %b){
+; CHECK-LABEL: sudot_different_types:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ushll v2.8h, v2.8b, #0
+; CHECK-NEXT:    sshll v3.4s, v1.4h, #0
+; CHECK-NEXT:    sshll2 v1.4s, v1.8h, #0
+; CHECK-NEXT:    ushll v4.4s, v2.4h, #0
+; CHECK-NEXT:    ushll2 v2.4s, v2.8h, #0
+; CHECK-NEXT:    smull v5.2d, v1.2s, v2.2s
+; CHECK-NEXT:    smlal v0.2d, v3.2s, v4.2s
+; CHECK-NEXT:    smlal2 v0.2d, v1.4s, v2.4s
+; CHECK-NEXT:    smlal2 v5.2d, v3.4s, v4.4s
+; CHECK-NEXT:    add v0.2d, v5.2d, v0.2d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = sext <8 x i16> %a to <8 x i64>
+  %b.wide = zext <8 x i8> %b to <8 x i64>
+  %mult = mul nuw nsw <8 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add.v2i64.v8i64(<2 x i64> %acc, <8 x i64> %mult)
+  ret <2 x i64> %partial.reduce
+}
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 56cd2c9d62b04f..66f83c658ff4f2 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -497,3 +497,133 @@ entry:
   %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
   ret <vscale x 2 x i64> %partial.reduce
 }
+
+define <vscale x 2 x i64> @udot_different_types(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i8> %b){
+; CHECK-LABEL: udot_different_types:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    and z2.h, z2.h, #0xff
+; CHECK-NEXT:    uunpklo z3.s, z1.h
+; CHECK-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    uunpklo z4.s, z2.h
+; CHECK-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEXT:    uunpklo z5.d, z3.s
+; CHECK-NEXT:    uunpkhi z3.d, z3.s
+; CHECK-NEXT:    uunpklo z7.d, z1.s
+; CHECK-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NEXT:    uunpklo z6.d, z4.s
+; CHECK-NEXT:    uunpkhi z4.d, z4.s
+; CHECK-NEXT:    uunpklo z24.d, z2.s
+; CHECK-NEXT:    uunpkhi z2.d, z2.s
+; CHECK-NEXT:    mul z3.d, z3.d, z4.d
+; CHECK-NEXT:    mla z0.d, p0/m, z5.d, z6.d
+; CHECK-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    movprfx z1, z3
+; CHECK-NEXT:    mla z1.d, p0/m, z7.d, z24.d
+; CHECK-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+  %b.wide = zext <vscale x 8 x i8> %b to <vscale x 8 x i64>
+  %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
+  ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 2 x i64> @sdot_different_types(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i8> %b){
+; CHECK-LABEL: sdot_different_types:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    sunpklo z3.s, z1.h
+; CHECK-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NEXT:    sxtb z2.h, p0/m, z2.h
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    sunpklo z5.d, z3.s
+; CHECK-NEXT:    sunpkhi z3.d, z3.s
+; CHECK-NEXT:    sunpklo z7.d, z1.s
+; CHECK-NEXT:    sunpklo z4.s, z2.h
+; CHECK-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NEXT:    sunpkhi z1.d, z1.s
+; CHECK-NEXT:    sunpklo z6.d, z4.s
+; CHECK-NEXT:    sunpkhi z4.d, z4.s
+; CHECK-NEXT:    sunpklo z24.d, z2.s
+; CHECK-NEXT:    sunpkhi z2.d, z2.s
+; CHECK-NEXT:    mul z3.d, z3.d, z4.d
+; CHECK-NEXT:    mla z0.d, p0/m, z5.d, z6.d
+; CHECK-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    movprfx z1, z3
+; CHECK-NEXT:    mla z1.d, p0/m, z7.d, z24.d
+; CHECK-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+  %b.wide = sext <vscale x 8 x i8> %b to <vscale x 8 x i64>
+  %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
+  ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 2 x i64> @usdot_different_types(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i8> %b){
+; CHECK-LABEL: usdot_different_types:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    uunpklo z3.s, z1.h
+; CHECK-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEXT:    sxtb z2.h, p0/m, z2.h
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    uunpklo z5.d, z3.s
+; CHECK-NEXT:    uunpkhi z3.d, z3.s
+; CHECK-NEXT:    uunpklo z7.d, z1.s
+; CHECK-NEXT:    sunpklo z4.s, z2.h
+; CHECK-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NEXT:    sunpklo z6.d, z4.s
+; CHECK-NEXT:    sunpkhi z4.d, z4.s
+; CHECK-NEXT:    sunpklo z24.d, z2.s
+; CHECK-NEXT:    sunpkhi z2.d, z2.s
+; CHECK-NEXT:    mul z3.d, z3.d, z4.d
+; CHECK-NEXT:    mla z0.d, p0/m, z5.d, z6.d
+; CHECK-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    movprfx z1, z3
+; CHECK-NEXT:    mla z1.d, p0/m, z7.d, z24.d
+; CHECK-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+  %b.wide = sext <vscale x 8 x i8> %b to <vscale x 8 x i64>
+  %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
+  ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 2 x i64> @sudot_different_types(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i8> %b){
+; CHECK-LABEL: sudot_different_types:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    and z2.h, z2.h, #0xff
+; CHECK-NEXT:    sunpklo z3.s, z1.h
+; CHECK-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    uunpklo z4.s, z2.h
+; CHECK-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEXT:    sunpklo z5.d, z3.s
+; CHECK-NEXT:    sunpkhi z3.d, z3.s
+; CHECK-NEXT:    sunpklo z7.d, z1.s
+; CHECK-NEXT:    sunpkhi z1.d, z1.s
+; CHECK-NEXT:    uunpklo z6.d, z4.s
+; CHECK-NEXT:    uunpkhi z4.d, z4.s
+; CHECK-NEXT:    uunpklo z24.d, z2.s
+; CHECK-NEXT:    uunpkhi z2.d, z2.s
+; CHECK-NEXT:    mul z3.d, z3.d, z4.d
+; CHECK-NEXT:    mla z0.d, p0/m, z5.d, z6.d
+; CHECK-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    movprfx z1, z3
+; CHECK-NEXT:    mla z1.d, p0/m, z7.d, z24.d
+; CHECK-NEXT:    add z0.d, z1.d, z0.d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+  %b.wide = zext <vscale x 8 x i8> %b to <vscale x 8 x i64>
+  %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
+  ret <vscale x 2 x i64> %partial.reduce
+}


