[llvm] [AArch64][GlobalISel] Support udot lowering for vecreduce add (PR #70784)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Oct 31 03:48:25 PDT 2023
llvmbot wrote:
@llvm/pr-subscribers-backend-aarch64
Author: None (chuongg3)
Changes:
vecreduce_add(mul(ext, ext)) -> vecreduce_add(udot)
vecreduce_add(ext x) -> vecreduce_add(udot(x, 1))
Applies to vectors with 8-bit elements whose element count is a multiple of 8.
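To illustrate the equivalence the combine relies on (this sketch is not part of the patch; the function names are made up for the example), a udot accumulates each group of four byte products into one 32-bit lane, so reducing those lanes gives the same result as reducing the fully widened multiply:

```cpp
// Illustrative only: reference semantics of the pattern being combined and
// of its udot-based lowering, checked against each other in plain C++.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// The pattern before the combine: widen every element to 32 bits, multiply,
// then reduce (vecreduce_add(mul(zext(a), zext(b)))).
uint32_t reduceWidenedMul(const std::vector<uint8_t> &A,
                          const std::vector<uint8_t> &B) {
  uint32_t Sum = 0;
  for (size_t I = 0; I < A.size(); ++I)
    Sum += uint32_t(A[I]) * uint32_t(B[I]);
  return Sum;
}

// UDOT-style evaluation: each 32-bit accumulator lane takes four adjacent
// byte products, then the lanes are reduced (the addv/addp in the tests).
uint32_t reduceViaDotLanes(const std::vector<uint8_t> &A,
                           const std::vector<uint8_t> &B) {
  assert(A.size() == B.size() && A.size() % 8 == 0);
  std::vector<uint32_t> Lanes(A.size() / 4, 0);
  for (size_t I = 0; I < A.size(); ++I)
    Lanes[I / 4] += uint32_t(A[I]) * uint32_t(B[I]);
  uint32_t Sum = 0;
  for (uint32_t L : Lanes)
    Sum += L;
  return Sum;
}

int main() {
  std::vector<uint8_t> A(16), B(16);
  for (size_t I = 0; I < 16; ++I) {
    A[I] = uint8_t(3 * I + 1);
    B[I] = uint8_t(7 * I + 2);
  }
  assert(reduceWidenedMul(A, B) == reduceViaDotLanes(A, B));
  std::cout << reduceWidenedMul(A, B) << "\n";
}
```

The same reasoning covers the ext-only case: dotting against an all-ones vector reduces to a plain widened sum, which is why the combine feeds a constant 1 as the second operand.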
---
Patch is 25.78 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/70784.diff
4 Files Affected:
- (modified) llvm/lib/Target/AArch64/AArch64Combine.td (+11-1)
- (modified) llvm/lib/Target/AArch64/AArch64InstrGISel.td (+15)
- (modified) llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp (+140)
- (modified) llvm/test/CodeGen/AArch64/vecreduce-add.ll (+220-113)
``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64Combine.td b/llvm/lib/Target/AArch64/AArch64Combine.td
index 017c4523c23a184..e17524b2c55bdd3 100644
--- a/llvm/lib/Target/AArch64/AArch64Combine.td
+++ b/llvm/lib/Target/AArch64/AArch64Combine.td
@@ -33,12 +33,22 @@ def fold_global_offset : GICombineRule<
(apply [{ applyFoldGlobalOffset(*${root}, MRI, B, Observer, ${matchinfo});}])
>;
+let Predicates = [HasDotProd] in {
+def ext_addv_to_udot_addv : GICombineRule<
+ (defs root:$root),
+ (match (wip_match_opcode G_VECREDUCE_ADD):$root,
+ [{ return matchExtAddvToUdotAddv(*${root}, MRI); }]),
+ (apply [{ applyExtAddvToUdotAddv(*${root}, MRI, B, Observer); }])
+>;
+}
+
def AArch64PreLegalizerCombiner: GICombiner<
"AArch64PreLegalizerCombinerImpl", [all_combines,
fconstant_to_constant,
icmp_redundant_trunc,
fold_global_offset,
- shuffle_to_extract]> {
+ shuffle_to_extract,
+ ext_addv_to_udot_addv]> {
let CombineAllMethodName = "tryCombineAllImpl";
}
diff --git a/llvm/lib/Target/AArch64/AArch64InstrGISel.td b/llvm/lib/Target/AArch64/AArch64InstrGISel.td
index 27338bd24393325..1711360779bf74c 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrGISel.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrGISel.td
@@ -227,6 +227,18 @@ def G_SMULL : AArch64GenericInstruction {
let hasSideEffects = 0;
}
+def G_UDOT : AArch64GenericInstruction {
+ let OutOperandList = (outs type0:$dst);
+ let InOperandList = (ins type0:$src1, type0:$src2, type0:$src3);
+ let hasSideEffects = 0;
+}
+
+def G_SDOT : AArch64GenericInstruction {
+ let OutOperandList = (outs type0:$dst);
+ let InOperandList = (ins type0:$src1, type0:$src2, type0:$src3);
+ let hasSideEffects = 0;
+}
+
// Generic instruction for the BSP pseudo. It is expanded into BSP, which
// expands into BSL/BIT/BIF after register allocation.
def G_BSP : AArch64GenericInstruction {
@@ -270,6 +282,9 @@ def : GINodeEquiv<G_BSP, AArch64bsp>;
def : GINodeEquiv<G_UMULL, AArch64umull>;
def : GINodeEquiv<G_SMULL, AArch64smull>;
+def : GINodeEquiv<G_UDOT, AArch64udot>;
+def : GINodeEquiv<G_SDOT, AArch64sdot>;
+
def : GINodeEquiv<G_EXTRACT_VECTOR_ELT, vector_extract>;
def : GINodeEquiv<G_PREFETCH, AArch64Prefetch>;
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp b/llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp
index d9678bea214dd53..34a59839a99a97c 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp
@@ -228,6 +228,146 @@ void applyFoldGlobalOffset(MachineInstr &MI, MachineRegisterInfo &MRI,
B.buildConstant(LLT::scalar(64), -static_cast<int64_t>(MinOffset)));
}
+// Combines vecreduce_add(mul(ext, ext)) -> vecreduce_add(udot)
+// Or vecreduce_add(ext) -> vecreduce_add(udot(x, 1))
+// Similar to performVecReduceAddCombine in SelectionDAG
+bool matchExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI) {
+ assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+ "Expected a G_VECREDUCE_ADD instruction");
+
+ MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
+ Register DstReg = MI.getOperand(0).getReg();
+ Register MidReg = I1->getOperand(0).getReg();
+ LLT DstTy = MRI.getType(DstReg);
+ LLT MidTy = MRI.getType(MidReg);
+ if (DstTy.getScalarSizeInBits() != 32 || MidTy.getScalarSizeInBits() != 32)
+ return false;
+
+ LLT SrcTy;
+ auto I1Opc = I1->getOpcode();
+ if (I1Opc == TargetOpcode::G_MUL) {
+ MachineInstr *ExtMI1 =
+ getDefIgnoringCopies(I1->getOperand(1).getReg(), MRI);
+ MachineInstr *ExtMI2 =
+ getDefIgnoringCopies(I1->getOperand(2).getReg(), MRI);
+ LLT Ext1DstTy = MRI.getType(ExtMI1->getOperand(0).getReg());
+ LLT Ext2DstTy = MRI.getType(ExtMI2->getOperand(0).getReg());
+
+ if (ExtMI1->getOpcode() != ExtMI2->getOpcode() || Ext1DstTy != Ext2DstTy)
+ return false;
+ I1Opc = ExtMI1->getOpcode();
+ SrcTy = MRI.getType(ExtMI1->getOperand(1).getReg());
+ } else
+ SrcTy = MRI.getType(I1->getOperand(1).getReg());
+
+ if (I1Opc != TargetOpcode::G_ZEXT && I1Opc != TargetOpcode::G_SEXT)
+ return false;
+ if (SrcTy.getScalarSizeInBits() != 8 || SrcTy.getNumElements() % 8 != 0)
+ return false;
+
+ return true;
+}
+
+void applyExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
+ MachineIRBuilder &Builder,
+ GISelChangeObserver &Observer) {
+ assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+ "Expected a G_VECREDUCE_ADD instruction");
+ MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
+ Register Ext1SrcReg, Ext2SrcReg;
+ unsigned DotOpcode;
+ if (I1->getOpcode() == TargetOpcode::G_MUL) {
+ auto Ext1MI = getDefIgnoringCopies(I1->getOperand(1).getReg(), MRI);
+ auto Ext2MI = getDefIgnoringCopies(I1->getOperand(2).getReg(), MRI);
+ Ext1SrcReg = Ext1MI->getOperand(1).getReg();
+ Ext2SrcReg = Ext2MI->getOperand(1).getReg();
+ DotOpcode = Ext1MI->getOpcode() == TargetOpcode::G_ZEXT ? AArch64::G_UDOT
+ : AArch64::G_SDOT;
+ } else if (I1->getOpcode() == TargetOpcode::G_ZEXT ||
+ I1->getOpcode() == TargetOpcode::G_SEXT) {
+ Ext1SrcReg = I1->getOperand(1).getReg();
+ Ext2SrcReg = Builder.buildConstant(MRI.getType(Ext1SrcReg), 1)
+ ->getOperand(0)
+ .getReg();
+ DotOpcode = I1->getOpcode() == TargetOpcode::G_ZEXT ? AArch64::G_UDOT
+ : AArch64::G_SDOT;
+ } else
+ return;
+
+ LLT SrcTy = MRI.getType(Ext1SrcReg);
+ LLT MidTy;
+ unsigned NumOfVecReduce;
+ if (SrcTy.getNumElements() % 16 == 0) {
+ NumOfVecReduce = SrcTy.getNumElements() / 16;
+ MidTy = LLT::fixed_vector(4, 32);
+ } else if (SrcTy.getNumElements() % 8 == 0) {
+ NumOfVecReduce = SrcTy.getNumElements() / 8;
+ MidTy = LLT::fixed_vector(2, 32);
+ } else
+ return;
+
+ // Handle case where one DOT instruction is needed
+ if (NumOfVecReduce == 1) {
+ auto Zeroes = Builder.buildConstant(MidTy, 0)->getOperand(0).getReg();
+ auto Dot = Builder.buildInstr(DotOpcode, {MidTy},
+ {Zeroes, Ext1SrcReg, Ext2SrcReg});
+ Builder.buildVecReduceAdd(MI.getOperand(0), Dot->getOperand(0));
+ } else {
+ // Get the number of output vectors needed
+ SmallVector<LLT, 4> DotVecLLT;
+ auto SrcVecNum = SrcTy.getNumElements();
+ while (SrcVecNum - 16 >= 16 || SrcVecNum - 16 == 0) {
+ DotVecLLT.push_back(LLT::fixed_vector(16, 8));
+ SrcVecNum = SrcVecNum - 16;
+ }
+ if (SrcVecNum == 8)
+ DotVecLLT.push_back(LLT::fixed_vector(8, 8));
+
+ // Unmerge the source vectors
+ auto Ext1Unmerge = Builder.buildUnmerge(DotVecLLT, Ext1SrcReg);
+ auto Ext2Unmerge = Builder.buildUnmerge(DotVecLLT, Ext2SrcReg);
+
+ // Build the UDOT instructions
+ SmallVector<Register, 2> DotReg;
+ unsigned NumElements = 0;
+ for (unsigned i = 0; i < DotVecLLT.size(); i++) {
+ LLT ZeroesLLT;
+      // Check if it is 16 or 8 elements. Set Zeroes to the corresponding size
+ if (MRI.getType(Ext1Unmerge.getReg(i)).getNumElements() == 16) {
+ ZeroesLLT = LLT::fixed_vector(4, 32);
+ NumElements += 4;
+ } else {
+ ZeroesLLT = LLT::fixed_vector(2, 32);
+ NumElements += 2;
+ }
+ auto Zeroes = Builder.buildConstant(ZeroesLLT, 0)->getOperand(0).getReg();
+ DotReg.push_back(Builder
+ .buildInstr(DotOpcode, {MRI.getType(Zeroes)},
+ {Zeroes, Ext1Unmerge.getReg(i),
+ Ext2Unmerge.getReg(i)})
+ ->getOperand(0)
+ .getReg());
+ }
+
+ // Merge the output
+ auto ConcatMI =
+ Builder.buildConcatVectors(LLT::fixed_vector(NumElements, 32), DotReg);
+
+ // Put it through a vector reduction
+ Builder.buildVecReduceAdd(MI.getOperand(0).getReg(),
+ ConcatMI->getOperand(0).getReg());
+ }
+
+ // Erase the dead instructions
+ if (I1->getOpcode() == TargetOpcode::G_MUL) {
+ getDefIgnoringCopies(I1->getOperand(1).getReg(), MRI)->eraseFromParent();
+ getDefIgnoringCopies(I1->getOperand(2).getReg(), MRI)->eraseFromParent();
+ }
+ I1->eraseFromParent();
+ MI.eraseFromParent();
+}
+
bool tryToSimplifyUADDO(MachineInstr &MI, MachineIRBuilder &B,
CombinerHelper &Helper, GISelChangeObserver &Observer) {
// Try simplify G_UADDO with 8 or 16 bit operands to wide G_ADD and TBNZ if
diff --git a/llvm/test/CodeGen/AArch64/vecreduce-add.ll b/llvm/test/CodeGen/AArch64/vecreduce-add.ll
index a88c930d09e9b17..b4b221bf4e46461 100644
--- a/llvm/test/CodeGen/AArch64/vecreduce-add.ll
+++ b/llvm/test/CodeGen/AArch64/vecreduce-add.ll
@@ -440,14 +440,10 @@ define i32 @add_v16i8_v16i32_zext(<16 x i8> %x) {
;
; CHECK-GI-LABEL: add_v16i8_v16i32_zext:
; CHECK-GI: // %bb.0: // %entry
-; CHECK-GI-NEXT: ushll v1.8h, v0.8b, #0
-; CHECK-GI-NEXT: ushll2 v0.8h, v0.16b, #0
-; CHECK-GI-NEXT: ushll v2.4s, v1.4h, #0
-; CHECK-GI-NEXT: ushll v3.4s, v0.4h, #0
-; CHECK-GI-NEXT: uaddw2 v1.4s, v2.4s, v1.8h
-; CHECK-GI-NEXT: uaddw2 v0.4s, v3.4s, v0.8h
-; CHECK-GI-NEXT: add v0.4s, v1.4s, v0.4s
-; CHECK-GI-NEXT: addv s0, v0.4s
+; CHECK-GI-NEXT: movi v1.16b, #1
+; CHECK-GI-NEXT: movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT: udot v2.4s, v0.16b, v1.16b
+; CHECK-GI-NEXT: addv s0, v2.4s
; CHECK-GI-NEXT: fmov w0, s0
; CHECK-GI-NEXT: ret
entry:
@@ -479,14 +475,10 @@ define i32 @add_v16i8_v16i32_sext(<16 x i8> %x) {
;
; CHECK-GI-LABEL: add_v16i8_v16i32_sext:
; CHECK-GI: // %bb.0: // %entry
-; CHECK-GI-NEXT: sshll v1.8h, v0.8b, #0
-; CHECK-GI-NEXT: sshll2 v0.8h, v0.16b, #0
-; CHECK-GI-NEXT: sshll v2.4s, v1.4h, #0
-; CHECK-GI-NEXT: sshll v3.4s, v0.4h, #0
-; CHECK-GI-NEXT: saddw2 v1.4s, v2.4s, v1.8h
-; CHECK-GI-NEXT: saddw2 v0.4s, v3.4s, v0.8h
-; CHECK-GI-NEXT: add v0.4s, v1.4s, v0.4s
-; CHECK-GI-NEXT: addv s0, v0.4s
+; CHECK-GI-NEXT: movi v1.16b, #1
+; CHECK-GI-NEXT: movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT: sdot v2.4s, v0.16b, v1.16b
+; CHECK-GI-NEXT: addv s0, v2.4s
; CHECK-GI-NEXT: fmov w0, s0
; CHECK-GI-NEXT: ret
entry:
@@ -514,10 +506,10 @@ define i32 @add_v8i8_v8i32_zext(<8 x i8> %x) {
;
; CHECK-GI-LABEL: add_v8i8_v8i32_zext:
; CHECK-GI: // %bb.0: // %entry
-; CHECK-GI-NEXT: ushll v0.8h, v0.8b, #0
-; CHECK-GI-NEXT: ushll v1.4s, v0.4h, #0
-; CHECK-GI-NEXT: uaddw2 v0.4s, v1.4s, v0.8h
-; CHECK-GI-NEXT: addv s0, v0.4s
+; CHECK-GI-NEXT: movi v1.8b, #1
+; CHECK-GI-NEXT: movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT: udot v2.2s, v0.8b, v1.8b
+; CHECK-GI-NEXT: addp v0.2s, v2.2s, v2.2s
; CHECK-GI-NEXT: fmov w0, s0
; CHECK-GI-NEXT: ret
entry:
@@ -545,10 +537,10 @@ define i32 @add_v8i8_v8i32_sext(<8 x i8> %x) {
;
; CHECK-GI-LABEL: add_v8i8_v8i32_sext:
; CHECK-GI: // %bb.0: // %entry
-; CHECK-GI-NEXT: sshll v0.8h, v0.8b, #0
-; CHECK-GI-NEXT: sshll v1.4s, v0.4h, #0
-; CHECK-GI-NEXT: saddw2 v0.4s, v1.4s, v0.8h
-; CHECK-GI-NEXT: addv s0, v0.4s
+; CHECK-GI-NEXT: movi v1.8b, #1
+; CHECK-GI-NEXT: movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT: sdot v2.2s, v0.8b, v1.8b
+; CHECK-GI-NEXT: addp v0.2s, v2.2s, v2.2s
; CHECK-GI-NEXT: fmov w0, s0
; CHECK-GI-NEXT: ret
entry:
@@ -1560,14 +1552,10 @@ define i32 @add_v16i8_v16i32_acc_zext(<16 x i8> %x, i32 %a) {
;
; CHECK-GI-LABEL: add_v16i8_v16i32_acc_zext:
; CHECK-GI: // %bb.0: // %entry
-; CHECK-GI-NEXT: ushll v1.8h, v0.8b, #0
-; CHECK-GI-NEXT: ushll2 v0.8h, v0.16b, #0
-; CHECK-GI-NEXT: ushll v2.4s, v1.4h, #0
-; CHECK-GI-NEXT: ushll v3.4s, v0.4h, #0
-; CHECK-GI-NEXT: uaddw2 v1.4s, v2.4s, v1.8h
-; CHECK-GI-NEXT: uaddw2 v0.4s, v3.4s, v0.8h
-; CHECK-GI-NEXT: add v0.4s, v1.4s, v0.4s
-; CHECK-GI-NEXT: addv s0, v0.4s
+; CHECK-GI-NEXT: movi v1.16b, #1
+; CHECK-GI-NEXT: movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT: udot v2.4s, v0.16b, v1.16b
+; CHECK-GI-NEXT: addv s0, v2.4s
; CHECK-GI-NEXT: fmov w8, s0
; CHECK-GI-NEXT: add w0, w8, w0
; CHECK-GI-NEXT: ret
@@ -1603,14 +1591,10 @@ define i32 @add_v16i8_v16i32_acc_sext(<16 x i8> %x, i32 %a) {
;
; CHECK-GI-LABEL: add_v16i8_v16i32_acc_sext:
; CHECK-GI: // %bb.0: // %entry
-; CHECK-GI-NEXT: sshll v1.8h, v0.8b, #0
-; CHECK-GI-NEXT: sshll2 v0.8h, v0.16b, #0
-; CHECK-GI-NEXT: sshll v2.4s, v1.4h, #0
-; CHECK-GI-NEXT: sshll v3.4s, v0.4h, #0
-; CHECK-GI-NEXT: saddw2 v1.4s, v2.4s, v1.8h
-; CHECK-GI-NEXT: saddw2 v0.4s, v3.4s, v0.8h
-; CHECK-GI-NEXT: add v0.4s, v1.4s, v0.4s
-; CHECK-GI-NEXT: addv s0, v0.4s
+; CHECK-GI-NEXT: movi v1.16b, #1
+; CHECK-GI-NEXT: movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT: sdot v2.4s, v0.16b, v1.16b
+; CHECK-GI-NEXT: addv s0, v2.4s
; CHECK-GI-NEXT: fmov w8, s0
; CHECK-GI-NEXT: add w0, w8, w0
; CHECK-GI-NEXT: ret
@@ -1642,10 +1626,10 @@ define i32 @add_v8i8_v8i32_acc_zext(<8 x i8> %x, i32 %a) {
;
; CHECK-GI-LABEL: add_v8i8_v8i32_acc_zext:
; CHECK-GI: // %bb.0: // %entry
-; CHECK-GI-NEXT: ushll v0.8h, v0.8b, #0
-; CHECK-GI-NEXT: ushll v1.4s, v0.4h, #0
-; CHECK-GI-NEXT: uaddw2 v0.4s, v1.4s, v0.8h
-; CHECK-GI-NEXT: addv s0, v0.4s
+; CHECK-GI-NEXT: movi v1.8b, #1
+; CHECK-GI-NEXT: movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT: udot v2.2s, v0.8b, v1.8b
+; CHECK-GI-NEXT: addp v0.2s, v2.2s, v2.2s
; CHECK-GI-NEXT: fmov w8, s0
; CHECK-GI-NEXT: add w0, w8, w0
; CHECK-GI-NEXT: ret
@@ -1677,10 +1661,10 @@ define i32 @add_v8i8_v8i32_acc_sext(<8 x i8> %x, i32 %a) {
;
; CHECK-GI-LABEL: add_v8i8_v8i32_acc_sext:
; CHECK-GI: // %bb.0: // %entry
-; CHECK-GI-NEXT: sshll v0.8h, v0.8b, #0
-; CHECK-GI-NEXT: sshll v1.4s, v0.4h, #0
-; CHECK-GI-NEXT: saddw2 v0.4s, v1.4s, v0.8h
-; CHECK-GI-NEXT: addv s0, v0.4s
+; CHECK-GI-NEXT: movi v1.8b, #1
+; CHECK-GI-NEXT: movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT: sdot v2.2s, v0.8b, v1.8b
+; CHECK-GI-NEXT: addp v0.2s, v2.2s, v2.2s
; CHECK-GI-NEXT: fmov w8, s0
; CHECK-GI-NEXT: add w0, w8, w0
; CHECK-GI-NEXT: ret
@@ -2618,6 +2602,152 @@ entry:
ret i32 %z
}
+define i32 @test_udot_v8i8(<8 x i8> %a, <8 x i8> %b) {
+; CHECK-BASE-LABEL: test_udot_v8i8:
+; CHECK-BASE: // %bb.0: // %entry
+; CHECK-BASE-NEXT: ushll v0.8h, v0.8b, #0
+; CHECK-BASE-NEXT: ushll v1.8h, v1.8b, #0
+; CHECK-BASE-NEXT: umull v2.4s, v1.4h, v0.4h
+; CHECK-BASE-NEXT: umlal2 v2.4s, v1.8h, v0.8h
+; CHECK-BASE-NEXT: addv s0, v2.4s
+; CHECK-BASE-NEXT: fmov w0, s0
+; CHECK-BASE-NEXT: ret
+;
+; CHECK-DOT-LABEL: test_udot_v8i8:
+; CHECK-DOT: // %bb.0: // %entry
+; CHECK-DOT-NEXT: movi v2.2d, #0000000000000000
+; CHECK-DOT-NEXT: udot v2.2s, v1.8b, v0.8b
+; CHECK-DOT-NEXT: addp v0.2s, v2.2s, v2.2s
+; CHECK-DOT-NEXT: fmov w0, s0
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-GI-LABEL: test_udot_v8i8:
+; CHECK-GI: // %bb.0: // %entry
+; CHECK-GI-NEXT: movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT: udot v2.2s, v1.8b, v0.8b
+; CHECK-GI-NEXT: addp v0.2s, v2.2s, v2.2s
+; CHECK-GI-NEXT: fmov w0, s0
+; CHECK-GI-NEXT: ret
+entry:
+ %0 = zext <8 x i8> %a to <8 x i32>
+ %1 = zext <8 x i8> %b to <8 x i32>
+ %2 = mul nuw nsw <8 x i32> %1, %0
+ %3 = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %2)
+ ret i32 %3
+}
+
+define i32 @test_udot_v16i8(<16 x i8> %a, <16 x i8> %b) {
+; CHECK-BASE-LABEL: test_udot_v16i8:
+; CHECK-BASE: // %bb.0: // %entry
+; CHECK-BASE-NEXT: ushll v2.8h, v0.8b, #0
+; CHECK-BASE-NEXT: ushll v3.8h, v1.8b, #0
+; CHECK-BASE-NEXT: ushll2 v0.8h, v0.16b, #0
+; CHECK-BASE-NEXT: ushll2 v1.8h, v1.16b, #0
+; CHECK-BASE-NEXT: umull v4.4s, v3.4h, v2.4h
+; CHECK-BASE-NEXT: umull2 v2.4s, v3.8h, v2.8h
+; CHECK-BASE-NEXT: umlal2 v2.4s, v1.8h, v0.8h
+; CHECK-BASE-NEXT: umlal v4.4s, v1.4h, v0.4h
+; CHECK-BASE-NEXT: add v0.4s, v4.4s, v2.4s
+; CHECK-BASE-NEXT: addv s0, v0.4s
+; CHECK-BASE-NEXT: fmov w0, s0
+; CHECK-BASE-NEXT: ret
+;
+; CHECK-DOT-LABEL: test_udot_v16i8:
+; CHECK-DOT: // %bb.0: // %entry
+; CHECK-DOT-NEXT: movi v2.2d, #0000000000000000
+; CHECK-DOT-NEXT: udot v2.4s, v1.16b, v0.16b
+; CHECK-DOT-NEXT: addv s0, v2.4s
+; CHECK-DOT-NEXT: fmov w0, s0
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-GI-LABEL: test_udot_v16i8:
+; CHECK-GI: // %bb.0: // %entry
+; CHECK-GI-NEXT: movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT: udot v2.4s, v1.16b, v0.16b
+; CHECK-GI-NEXT: addv s0, v2.4s
+; CHECK-GI-NEXT: fmov w0, s0
+; CHECK-GI-NEXT: ret
+entry:
+ %0 = zext <16 x i8> %a to <16 x i32>
+ %1 = zext <16 x i8> %b to <16 x i32>
+ %2 = mul nuw nsw <16 x i32> %1, %0
+ %3 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %2)
+ ret i32 %3
+}
+
+define i32 @test_sdot_v8i8(<8 x i8> %a, <8 x i8> %b) {
+; CHECK-BASE-LABEL: test_sdot_v8i8:
+; CHECK-BASE: // %bb.0: // %entry
+; CHECK-BASE-NEXT: sshll v0.8h, v0.8b, #0
+; CHECK-BASE-NEXT: sshll v1.8h, v1.8b, #0
+; CHECK-BASE-NEXT: smull v2.4s, v1.4h, v0.4h
+; CHECK-BASE-NEXT: smlal2 v2.4s, v1.8h, v0.8h
+; CHECK-BASE-NEXT: addv s0, v2.4s
+; CHECK-BASE-NEXT: fmov w0, s0
+; CHECK-BASE-NEXT: ret
+;
+; CHECK-DOT-LABEL: test_sdot_v8i8:
+; CHECK-DOT: // %bb.0: // %entry
+; CHECK-DOT-NEXT: movi v2.2d, #0000000000000000
+; CHECK-DOT-NEXT: sdot v2.2s, v1.8b, v0.8b
+; CHECK-DOT-NEXT: addp v0.2s, v2.2s, v2.2s
+; CHECK-DOT-NEXT: fmov w0, s0
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-GI-LABEL: test_sdot_v8i8:
+; CHECK-GI: // %bb.0: // %entry
+; CHECK-GI-NEXT: movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT: sdot v2.2s, v1.8b, v0.8b
+; CHECK-GI-NEXT: addp v0.2s, v2.2s, v2.2s
+; CHECK-GI-NEXT: fmov w0, s0
+; CHECK-GI-NEXT: ret
+entry:
+ %0 = sext <8 x i8> %a to <8 x i32>
+ %1 = sext <8 x i8> %b to <8 x i32>
+ %2 = mul nuw nsw <8 x i32> %1, %0
+ %3 = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %2)
+ ret i32 %3
+}
+
+define i32 @test_sdot_v16i8(<16 x i8> %a, <16 x i8> %b) {
+; CHECK-BASE-LABEL: test_sdot_v16i8:
+; CHECK-BASE: // %bb.0: // %entry
+; CHECK-BASE-NEXT: sshll v2.8h, v0.8b, #0
+; CHECK-BASE-NEXT: sshll v3.8h, v1.8b, #0
+; CHECK-BASE-NEXT: sshll2 v0.8h, v0.16b, #0
+; CHECK-BASE-NEXT: sshll2 v1.8h, v1.16b, #0
+; CHECK-BASE-NEXT: smull v4.4s, v3.4h, v2.4h
+; CHECK-BASE-NEXT: smull2 v2.4s, v3.8h, v2.8h
+; CHECK-BASE-NEXT: smlal2 v2.4s, v1.8h, v0.8h
+; CHECK-BASE-NEXT: smlal v4.4s, v1.4h, v0.4h
+; CHECK-BASE-NEXT: add v0.4s, v4.4s, v2.4s
+; CHECK-BASE-NEXT: addv s0, v0.4s
+; CHECK-BASE-NEXT: fmov w0, s0
+; CHECK-BASE-NEXT: ret
+;
+; CHECK-DOT-LABEL: test_sdot_v16i8:
+; CHECK-DOT: // %bb.0: // %entry
+; CHECK-DOT-NEXT: movi v2.2d, #0000000000000000
+; CHECK-DOT-NEXT: sdot v2.4s, v1.16b, v0.16b
+; CHECK-DOT-NEXT: addv s0, v2.4s
+; CHECK-DOT-NEXT: fmov w0, s0
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-GI-LABEL: test_sdot_v16i8:
+; CHECK-GI: // %bb.0: // %entry
+; CHECK-GI-NEXT: movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT: sdot v2.4s, v1.16b, v0.16b
+; CHECK-GI-NEXT: addv s0, v2.4s
+; CHECK-GI-NEXT: fmov w0, s0
+; CHECK-GI-NEXT: ret
+entry:
+ %0 = sext <16 x i8> %a to <16 x i32>
+ %1 = sext <16 x i8> %b to <16 x i32>
+ %2 = mul nuw nsw <16 x i32> %1, %0
+ %3 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %2)
+ ret i32 %3
+}
+
define zeroext i16 @add_pair_v8i16_v8i16(<8 x i16> %x, <8 x i16> %y) {
; CHECK-BASE-LABEL: add_pair_v8i16_v8i16:
; CHECK-BASE: // %bb.0: // %entry
@@ -2990,22 +3120,13 @@ define i32 @add_pair_v16i8_v16i32_zext(<16 x i8> %x, <16 x i8> %y) {
;
; CHECK-GI-LABEL: add_pair_v16i8_v16i32_zext:
; CHECK-GI: // %bb.0: // %entry
-; CHECK-GI-NEXT: ushll v2.8h, v0.8b, #0
-; CHECK-GI-NEXT: ushll2 v0.8h, v0.16b, #0
-; CHECK-GI-NEXT: ushll v3.8h, v1.8b, #0
-; CHECK-GI-NEXT: ushll2 v1.8h, v1....
[truncated]
``````````
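For source vectors longer than 16 elements, the apply routine unmerges the inputs, emits one dot product per chunk, concatenates the 32-bit partial accumulators, and reduces the result. Below is a rough, self-contained sketch of that chunking arithmetic, assuming an element count that is a multiple of 8; it shows the intended 16-plus-trailing-8 split rather than reproducing the combiner's exact loop, and `dotChunks` is a hypothetical helper name.

```cpp
// Illustrative only: how an N-element i8 source (N a multiple of 8) can be
// split into 16-element chunks, plus at most one trailing 8-element chunk,
// each feeding a separate dot product whose 32-bit accumulator lanes are
// later concatenated and reduced. Not LLVM code.
#include <cstdio>
#include <vector>

std::vector<unsigned> dotChunks(unsigned NumElts) {
  std::vector<unsigned> Chunks;
  while (NumElts >= 16) { // one v16i8 -> v4i32 dot product per 16 elements
    Chunks.push_back(16);
    NumElts -= 16;
  }
  if (NumElts == 8)       // leftover v8i8 -> v2i32 dot product
    Chunks.push_back(8);
  return Chunks;
}

int main() {
  for (unsigned N : {8u, 16u, 24u, 32u, 64u}) {
    std::printf("%u elements:", N);
    for (unsigned C : dotChunks(N))
      std::printf(" %u", C);
    std::printf("\n");
  }
}
```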
https://github.com/llvm/llvm-project/pull/70784
More information about the llvm-commits mailing list