[llvm] [AArch64][SVE] Add dot product codegen for partial reductions with no binary operation on input (PR #120207)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Dec 17 02:07:02 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-aarch64
Author: James Chesterman (JamesChesterman)
<details>
<summary>Changes</summary>
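Add dot-product codegen for partial reductions where the input type has four times as many elements as the output type and the input is not the result of a binary operation: it is a plain zero or sign extension rather than a multiply of two extensions. In that case the unextended operand is dot-multiplied against a splat of 1, so the reduction lowers to `udot` (zero extension) or `sdot` (sign extension).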
---
Full diff: https://github.com/llvm/llvm-project/pull/120207.diff
3 Files Affected:
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+50-32)
- (modified) llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll (+160)
- (modified) llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll (+78)
``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index d1354ccf376609..1e5b80174e8cfd 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21741,45 +21741,63 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
// The narrower of the two operands. Used as the accumulator
auto NarrowOp = N->getOperand(1);
auto MulOp = N->getOperand(2);
- if (MulOp->getOpcode() != ISD::MUL)
- return SDValue();
- auto ExtA = MulOp->getOperand(0);
- auto ExtB = MulOp->getOperand(1);
+ unsigned MulOpcode = MulOp->getOpcode();
+ EVT ReducedVT = N->getValueType(0);
+ EVT MulOpVT = MulOp->getValueType(0);
+ unsigned Opcode = 0;
+ bool AIsSigned, BIsSigned;
+ SDValue A, B;
+ if (MulOpcode != ISD::MUL && ReducedVT.getVectorElementCount() * 4 ==
+ MulOpVT.getVectorElementCount()) {
+ if (!ISD::isExtOpcode(MulOpcode))
+ return SDValue();
+ AIsSigned = MulOpcode == ISD::SIGN_EXTEND;
+ BIsSigned = AIsSigned;
+ SDValue NewMulOp = MulOp->getOperand(0);
+ Opcode = AIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
+ A = NewMulOp;
+ B = DAG.getConstant(1, DL, NewMulOp.getValueType());
- if (!ISD::isExtOpcode(ExtA->getOpcode()) ||
- !ISD::isExtOpcode(ExtB->getOpcode()))
- return SDValue();
- bool AIsSigned = ExtA->getOpcode() == ISD::SIGN_EXTEND;
- bool BIsSigned = ExtB->getOpcode() == ISD::SIGN_EXTEND;
+ } else {
+ if (MulOp->getOpcode() != ISD::MUL)
+ return SDValue();
- auto A = ExtA->getOperand(0);
- auto B = ExtB->getOperand(0);
- if (A.getValueType() != B.getValueType())
- return SDValue();
+ auto ExtA = MulOp->getOperand(0);
+ auto ExtB = MulOp->getOperand(1);
- EVT ReducedType = N->getValueType(0);
- EVT MulSrcType = A.getValueType();
+ if (!ISD::isExtOpcode(ExtA->getOpcode()) ||
+ !ISD::isExtOpcode(ExtB->getOpcode()))
+ return SDValue();
+ AIsSigned = ExtA->getOpcode() == ISD::SIGN_EXTEND;
+ BIsSigned = ExtB->getOpcode() == ISD::SIGN_EXTEND;
+
+ A = ExtA->getOperand(0);
+ B = ExtB->getOperand(0);
+ if (A.getValueType() != B.getValueType())
+ return SDValue();
+ }
+
+ EVT MulSrcVT = A.getValueType();
// Dot products operate on chunks of four elements so there must be four times
// as many elements in the wide type
- if (!(ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) &&
- !(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
- !(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) &&
- !(ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8) &&
- !(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) &&
- !(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
+ if (!(ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
+ !(ReducedVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
+ !(ReducedVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
+ !(ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
+ !(ReducedVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
+ !(ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
return SDValue();
// If the extensions are mixed, we should lower it to a usdot instead
- unsigned Opcode = 0;
if (AIsSigned != BIsSigned) {
if (!Subtarget->hasMatMulInt8())
return SDValue();
bool Scalable = N->getValueType(0).isScalableVT();
// There's no nxv2i64 version of usdot
- if (Scalable && ReducedType != MVT::nxv4i32 && ReducedType != MVT::nxv4i64)
+ if (Scalable && ReducedVT != MVT::nxv4i32 && ReducedVT != MVT::nxv4i64)
return SDValue();
Opcode = AArch64ISD::USDOT;
@@ -21793,19 +21811,19 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
// Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
// product followed by a zero / sign extension
- if ((ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) ||
- (ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8)) {
- EVT ReducedTypeI32 =
- (ReducedType.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
-
- auto DotI32 = DAG.getNode(Opcode, DL, ReducedTypeI32,
- DAG.getConstant(0, DL, ReducedTypeI32), A, B);
- auto Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedType);
+ if ((ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
+ (ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) {
+ EVT ReducedVTI32 =
+ (ReducedVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
+
+ auto DotI32 = DAG.getNode(Opcode, DL, ReducedVTI32,
+ DAG.getConstant(0, DL, ReducedVTI32), A, B);
+ auto Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedVT);
return DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(), NarrowOp,
Extended);
}
- return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
+ return DAG.getNode(Opcode, DL, ReducedVT, NarrowOp, A, B);
}
SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index c1b9a4c9dbb797..c0987582efc261 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -367,6 +367,166 @@ entry:
ret <4 x i64> %partial.reduce
}
+define <4 x i32> @udot_no_bin_op(<4 x i32> %acc, <16 x i8> %a){
+; CHECK-DOT-LABEL: udot_no_bin_op:
+; CHECK-DOT: // %bb.0:
+; CHECK-DOT-NEXT: movi v2.16b, #1
+; CHECK-DOT-NEXT: udot v0.4s, v1.16b, v2.16b
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-NODOT-LABEL: udot_no_bin_op:
+; CHECK-NODOT: // %bb.0:
+; CHECK-NODOT-NEXT: ushll v2.8h, v1.8b, #0
+; CHECK-NODOT-NEXT: ushll2 v1.8h, v1.16b, #0
+; CHECK-NODOT-NEXT: ushll v3.4s, v1.4h, #0
+; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v2.4h
+; CHECK-NODOT-NEXT: uaddw2 v2.4s, v3.4s, v2.8h
+; CHECK-NODOT-NEXT: uaddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
+; CHECK-NODOT-NEXT: ret
+ %a.wide = zext <16 x i8> %a to <16 x i32>
+ %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %a.wide)
+ ret <4 x i32> %partial.reduce
+}
+
+define <4 x i32> @sdot_no_bin_op(<4 x i32> %acc, <16 x i8> %a){
+; CHECK-DOT-LABEL: sdot_no_bin_op:
+; CHECK-DOT: // %bb.0:
+; CHECK-DOT-NEXT: movi v2.16b, #1
+; CHECK-DOT-NEXT: sdot v0.4s, v1.16b, v2.16b
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-NODOT-LABEL: sdot_no_bin_op:
+; CHECK-NODOT: // %bb.0:
+; CHECK-NODOT-NEXT: sshll v2.8h, v1.8b, #0
+; CHECK-NODOT-NEXT: sshll2 v1.8h, v1.16b, #0
+; CHECK-NODOT-NEXT: sshll v3.4s, v1.4h, #0
+; CHECK-NODOT-NEXT: saddw v0.4s, v0.4s, v2.4h
+; CHECK-NODOT-NEXT: saddw2 v2.4s, v3.4s, v2.8h
+; CHECK-NODOT-NEXT: saddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
+; CHECK-NODOT-NEXT: ret
+ %a.wide = sext <16 x i8> %a to <16 x i32>
+ %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %a.wide)
+ ret <4 x i32> %partial.reduce
+}
+
+define <2 x i32> @udot_no_bin_op_narrow(<2 x i32> %acc, <8 x i8> %a){
+; CHECK-DOT-LABEL: udot_no_bin_op_narrow:
+; CHECK-DOT: // %bb.0:
+; CHECK-DOT-NEXT: movi v2.8b, #1
+; CHECK-DOT-NEXT: udot v0.2s, v1.8b, v2.8b
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-NODOT-LABEL: udot_no_bin_op_narrow:
+; CHECK-NODOT: // %bb.0:
+; CHECK-NODOT-NEXT: ushll v1.8h, v1.8b, #0
+; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-NODOT-NEXT: ushll v2.4s, v1.4h, #0
+; CHECK-NODOT-NEXT: ushll2 v3.4s, v1.8h, #0
+; CHECK-NODOT-NEXT: ext v4.16b, v1.16b, v1.16b, #8
+; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v1.4h
+; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8
+; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
+; CHECK-NODOT-NEXT: uaddw v1.4s, v2.4s, v4.4h
+; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NODOT-NEXT: ret
+ %a.wide = zext <8 x i8> %a to <8 x i32>
+ %partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v2i32.v8i32(<2 x i32> %acc, <8 x i32> %a.wide)
+ ret <2 x i32> %partial.reduce
+}
+
+define <2 x i32> @sdot_no_bin_op_narrow(<2 x i32> %acc, <8 x i8> %a){
+; CHECK-DOT-LABEL: sdot_no_bin_op_narrow:
+; CHECK-DOT: // %bb.0:
+; CHECK-DOT-NEXT: movi v2.8b, #1
+; CHECK-DOT-NEXT: sdot v0.2s, v1.8b, v2.8b
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-NODOT-LABEL: sdot_no_bin_op_narrow:
+; CHECK-NODOT: // %bb.0:
+; CHECK-NODOT-NEXT: sshll v1.8h, v1.8b, #0
+; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-NODOT-NEXT: sshll v2.4s, v1.4h, #0
+; CHECK-NODOT-NEXT: sshll2 v3.4s, v1.8h, #0
+; CHECK-NODOT-NEXT: ext v4.16b, v1.16b, v1.16b, #8
+; CHECK-NODOT-NEXT: saddw v0.4s, v0.4s, v1.4h
+; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8
+; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
+; CHECK-NODOT-NEXT: saddw v1.4s, v2.4s, v4.4h
+; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NODOT-NEXT: ret
+ %a.wide = sext <8 x i8> %a to <8 x i32>
+ %partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v2i32.v8i32(<2 x i32> %acc, <8 x i32> %a.wide)
+ ret <2 x i32> %partial.reduce
+}
+
+define <4 x i64> @udot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){
+; CHECK-DOT-LABEL: udot_no_bin_op_8to64:
+; CHECK-DOT: // %bb.0:
+; CHECK-DOT-NEXT: movi v3.16b, #1
+; CHECK-DOT-NEXT: movi v4.2d, #0000000000000000
+; CHECK-DOT-NEXT: udot v4.4s, v2.16b, v3.16b
+; CHECK-DOT-NEXT: saddw2 v1.2d, v1.2d, v4.4s
+; CHECK-DOT-NEXT: saddw v0.2d, v0.2d, v4.2s
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-NODOT-LABEL: udot_no_bin_op_8to64:
+; CHECK-NODOT: // %bb.0:
+; CHECK-NODOT-NEXT: ushll v3.8h, v2.8b, #0
+; CHECK-NODOT-NEXT: ushll2 v2.8h, v2.16b, #0
+; CHECK-NODOT-NEXT: ushll v4.4s, v3.4h, #0
+; CHECK-NODOT-NEXT: ushll v5.4s, v2.4h, #0
+; CHECK-NODOT-NEXT: ushll2 v3.4s, v3.8h, #0
+; CHECK-NODOT-NEXT: ushll2 v2.4s, v2.8h, #0
+; CHECK-NODOT-NEXT: uaddw2 v1.2d, v1.2d, v4.4s
+; CHECK-NODOT-NEXT: uaddw v0.2d, v0.2d, v4.2s
+; CHECK-NODOT-NEXT: uaddl2 v4.2d, v3.4s, v5.4s
+; CHECK-NODOT-NEXT: uaddl v3.2d, v3.2s, v5.2s
+; CHECK-NODOT-NEXT: uaddw2 v1.2d, v1.2d, v2.4s
+; CHECK-NODOT-NEXT: uaddw v0.2d, v0.2d, v2.2s
+; CHECK-NODOT-NEXT: add v1.2d, v4.2d, v1.2d
+; CHECK-NODOT-NEXT: add v0.2d, v3.2d, v0.2d
+; CHECK-NODOT-NEXT: ret
+ %a.wide = zext <16 x i8> %a to <16 x i64>
+ %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(<4 x i64> %acc, <16 x i64> %a.wide)
+ ret <4 x i64> %partial.reduce
+}
+
+define <4 x i64> @sdot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){
+; CHECK-DOT-LABEL: sdot_no_bin_op_8to64:
+; CHECK-DOT: // %bb.0:
+; CHECK-DOT-NEXT: movi v3.16b, #1
+; CHECK-DOT-NEXT: movi v4.2d, #0000000000000000
+; CHECK-DOT-NEXT: sdot v4.4s, v2.16b, v3.16b
+; CHECK-DOT-NEXT: saddw2 v1.2d, v1.2d, v4.4s
+; CHECK-DOT-NEXT: saddw v0.2d, v0.2d, v4.2s
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-NODOT-LABEL: sdot_no_bin_op_8to64:
+; CHECK-NODOT: // %bb.0:
+; CHECK-NODOT-NEXT: sshll v3.8h, v2.8b, #0
+; CHECK-NODOT-NEXT: sshll2 v2.8h, v2.16b, #0
+; CHECK-NODOT-NEXT: sshll v4.4s, v3.4h, #0
+; CHECK-NODOT-NEXT: sshll v5.4s, v2.4h, #0
+; CHECK-NODOT-NEXT: sshll2 v3.4s, v3.8h, #0
+; CHECK-NODOT-NEXT: sshll2 v2.4s, v2.8h, #0
+; CHECK-NODOT-NEXT: saddw2 v1.2d, v1.2d, v4.4s
+; CHECK-NODOT-NEXT: saddw v0.2d, v0.2d, v4.2s
+; CHECK-NODOT-NEXT: saddl2 v4.2d, v3.4s, v5.4s
+; CHECK-NODOT-NEXT: saddl v3.2d, v3.2s, v5.2s
+; CHECK-NODOT-NEXT: saddw2 v1.2d, v1.2d, v2.4s
+; CHECK-NODOT-NEXT: saddw v0.2d, v0.2d, v2.2s
+; CHECK-NODOT-NEXT: add v1.2d, v4.2d, v1.2d
+; CHECK-NODOT-NEXT: add v0.2d, v3.2d, v0.2d
+; CHECK-NODOT-NEXT: ret
+ %a.wide = sext <16 x i8> %a to <16 x i64>
+ %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(<4 x i64> %acc, <16 x i64> %a.wide)
+ ret <4 x i64> %partial.reduce
+}
+
define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
; CHECK-LABEL: not_udot:
; CHECK: // %bb.0:
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 66d6e0388bbf94..56cd2c9d62b04f 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -316,6 +316,84 @@ entry:
ret <vscale x 4 x i64> %partial.reduce
}
+define <vscale x 4 x i32> @udot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a){
+; CHECK-LABEL: udot_no_bin_op:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z2.b, #1 // =0x1
+; CHECK-NEXT: udot z0.s, z1.b, z2.b
+; CHECK-NEXT: ret
+ %a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
+ ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @sdot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a){
+; CHECK-LABEL: sdot_no_bin_op:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z2.b, #1 // =0x1
+; CHECK-NEXT: sdot z0.s, z1.b, z2.b
+; CHECK-NEXT: ret
+ %a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
+ ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @udot_no_bin_op_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b){
+; CHECK-LABEL: udot_no_bin_op_wide:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z2.h, #1 // =0x1
+; CHECK-NEXT: udot z0.d, z1.h, z2.h
+; CHECK-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+ %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %a.wide)
+ ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 2 x i64> @sdot_no_bin_op_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b){
+; CHECK-LABEL: sdot_no_bin_op_wide:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z2.h, #1 // =0x1
+; CHECK-NEXT: sdot z0.d, z1.h, z2.h
+; CHECK-NEXT: ret
+entry:
+ %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+ %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %a.wide)
+ ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i64> @udot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a){
+; CHECK-LABEL: udot_no_bin_op_8to64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z3.b, #1 // =0x1
+; CHECK-NEXT: mov z4.s, #0 // =0x0
+; CHECK-NEXT: udot z4.s, z2.b, z3.b
+; CHECK-NEXT: sunpklo z2.d, z4.s
+; CHECK-NEXT: sunpkhi z3.d, z4.s
+; CHECK-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEXT: add z1.d, z1.d, z3.d
+; CHECK-NEXT: ret
+ %a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+ %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64> %acc, <vscale x 16 x i64> %a.ext)
+ ret <vscale x 4 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i64> @sdot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a){
+; CHECK-LABEL: sdot_no_bin_op_8to64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z3.b, #1 // =0x1
+; CHECK-NEXT: mov z4.s, #0 // =0x0
+; CHECK-NEXT: sdot z4.s, z2.b, z3.b
+; CHECK-NEXT: sunpklo z2.d, z4.s
+; CHECK-NEXT: sunpkhi z3.d, z4.s
+; CHECK-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEXT: add z1.d, z1.d, z3.d
+; CHECK-NEXT: ret
+ %a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+ %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64> %acc, <vscale x 16 x i64> %a.ext)
+ ret <vscale x 4 x i64> %partial.reduce
+}
+
define <vscale x 4 x i32> @not_udot(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
; CHECK-LABEL: not_udot:
; CHECK: // %bb.0: // %entry
``````````
</details>
https://github.com/llvm/llvm-project/pull/120207
More information about the llvm-commits mailing list