[llvm] [AArch64][GlobalISel] Support udot lowering for vecreduce add (PR #70784)

via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 31 03:47:16 PDT 2023


https://github.com/chuongg3 created https://github.com/llvm/llvm-project/pull/70784

vecreduce_add(mul(ext(x), ext(y))) -> vecreduce_add(udot/sdot(x, y))
vecreduce_add(ext(x)) -> vecreduce_add(udot/sdot(x, splat(1)))

Applies to source vectors with 8-bit elements whose element count is a multiple of 8.

>From 7a53eae0e223085760d81fbf4dea6489224babd8 Mon Sep 17 00:00:00 2001
From: Tuan Chuong Goh <chuong.goh at arm.com>
Date: Mon, 30 Oct 2023 09:51:47 +0000
Subject: [PATCH] [AArch64][GlobalISel] Support udot lowering for vecreduce add

vecreduce_add(mul(ext(x), ext(y))) -> vecreduce_add(udot/sdot(x, y))
vecreduce_add(ext(x)) -> vecreduce_add(udot/sdot(x, splat(1)))

Applies to source vectors with 8-bit elements whose element count is a multiple of 8.
---
 llvm/lib/Target/AArch64/AArch64Combine.td     |  12 +-
 llvm/lib/Target/AArch64/AArch64InstrGISel.td  |  15 +
 .../GISel/AArch64PreLegalizerCombiner.cpp     | 140 ++++++++
 llvm/test/CodeGen/AArch64/vecreduce-add.ll    | 333 ++++++++++++------
 4 files changed, 386 insertions(+), 114 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64Combine.td b/llvm/lib/Target/AArch64/AArch64Combine.td
index 017c4523c23a184..e17524b2c55bdd3 100644
--- a/llvm/lib/Target/AArch64/AArch64Combine.td
+++ b/llvm/lib/Target/AArch64/AArch64Combine.td
@@ -33,12 +33,22 @@ def fold_global_offset : GICombineRule<
   (apply [{ applyFoldGlobalOffset(*${root}, MRI, B, Observer, ${matchinfo});}])
 >;
 
+// Rewrites vecreduce_add(mul(ext(x), ext(y))) and vecreduce_add(ext(x)) into a
+// G_VECREDUCE_ADD of G_UDOT/G_SDOT results; the C++ side is
+// matchExtAddvToUdotAddv / applyExtAddvToUdotAddv in
+// AArch64PreLegalizerCombiner.cpp. Only enabled when the subtarget satisfies
+// HasDotProd.
+let Predicates = [HasDotProd] in {
+def ext_addv_to_udot_addv : GICombineRule<
+  (defs root:$root),
+  (match (wip_match_opcode G_VECREDUCE_ADD):$root,
+         [{ return matchExtAddvToUdotAddv(*${root}, MRI); }]),
+  (apply [{ applyExtAddvToUdotAddv(*${root}, MRI, B, Observer); }])
+>;
+}
+
 def AArch64PreLegalizerCombiner: GICombiner<
   "AArch64PreLegalizerCombinerImpl", [all_combines,
                                       fconstant_to_constant,
                                       icmp_redundant_trunc,
                                       fold_global_offset,
-                                      shuffle_to_extract]> {
+                                      shuffle_to_extract,
+                                      // Gated on HasDotProd by its Predicates.
+                                      ext_addv_to_udot_addv]> {
   let CombineAllMethodName = "tryCombineAllImpl";
 }
 
diff --git a/llvm/lib/Target/AArch64/AArch64InstrGISel.td b/llvm/lib/Target/AArch64/AArch64InstrGISel.td
index 27338bd24393325..1711360779bf74c 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrGISel.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrGISel.td
@@ -227,6 +227,18 @@ def G_SMULL : AArch64GenericInstruction {
   let hasSideEffects = 0;
 }
 
+// Generic AArch64 unsigned dot product, mapped to AArch64udot by a
+// GINodeEquiv: each 32-bit lane of $dst is the matching lane of $src1 plus the
+// sum of four unsigned byte products taken from $src2 and $src3.
+// The byte operands have a different type from the accumulator/result (e.g.
+// <16 x s8> vs <4 x s32>), so they get their own type index; sharing type0
+// would make the MachineVerifier report a type mismatch.
+def G_UDOT : AArch64GenericInstruction {
+  let OutOperandList = (outs type0:$dst);
+  let InOperandList = (ins type0:$src1, type1:$src2, type1:$src3);
+  let hasSideEffects = 0;
+}
+
+// Generic AArch64 signed dot product, mapped to AArch64sdot by a GINodeEquiv:
+// each 32-bit lane of $dst is the matching lane of $src1 plus the sum of four
+// signed byte products taken from $src2 and $src3.
+// As for G_UDOT, the byte operands use their own type index because their
+// type differs from the accumulator/result type.
+def G_SDOT : AArch64GenericInstruction {
+  let OutOperandList = (outs type0:$dst);
+  let InOperandList = (ins type0:$src1, type1:$src2, type1:$src3);
+  let hasSideEffects = 0;
+}
+
 // Generic instruction for the BSP pseudo. It is expanded into BSP, which
 // expands into BSL/BIT/BIF after register allocation.
 def G_BSP : AArch64GenericInstruction {
@@ -270,6 +282,9 @@ def : GINodeEquiv<G_BSP, AArch64bsp>;
 def : GINodeEquiv<G_UMULL, AArch64umull>;
 def : GINodeEquiv<G_SMULL, AArch64smull>;
 
+// Let the imported SelectionDAG patterns for AArch64udot/AArch64sdot select
+// the GlobalISel G_UDOT/G_SDOT opcodes built by the pre-legalizer combiner.
+def : GINodeEquiv<G_UDOT, AArch64udot>;
+def : GINodeEquiv<G_SDOT, AArch64sdot>;
+
 def : GINodeEquiv<G_EXTRACT_VECTOR_ELT, vector_extract>;
 
 def : GINodeEquiv<G_PREFETCH, AArch64Prefetch>;
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp b/llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp
index d9678bea214dd53..34a59839a99a97c 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp
@@ -228,6 +228,146 @@ void applyFoldGlobalOffset(MachineInstr &MI, MachineRegisterInfo &MRI,
       B.buildConstant(LLT::scalar(64), -static_cast<int64_t>(MinOffset)));
 }
 
+// Matches G_VECREDUCE_ADD roots that can be computed with dot products:
+//   vecreduce_add(mul(ext(x), ext(y))) -> vecreduce_add(dot(0, x, y))
+//   vecreduce_add(ext(x))              -> vecreduce_add(dot(0, x, splat(1)))
+// Both extends must be the same kind (G_ZEXT selects UDOT, G_SEXT selects
+// SDOT), x and y must share one <N x s8> type, and the product/reduction must
+// use s32 elements. N must be 8 or a multiple of 16: those are the lane counts
+// that split into equally sized DOT inputs (one <8 x s8>, or k x <16 x s8>).
+// Similar to performVecReduceAddCombine in SelectionDAG.
+bool matchExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI) {
+  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+         "Expected a G_VECREDUCE_ADD instruction");
+
+  MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
+  if (!I1)
+    return false;
+  Register DstReg = MI.getOperand(0).getReg();
+  Register MidReg = I1->getOperand(0).getReg();
+  if (MRI.getType(DstReg).getScalarSizeInBits() != 32 ||
+      MRI.getType(MidReg).getScalarSizeInBits() != 32)
+    return false;
+
+  auto IsExt = [](unsigned Opc) {
+    return Opc == TargetOpcode::G_ZEXT || Opc == TargetOpcode::G_SEXT;
+  };
+
+  // Check opcodes before reading operand 1: the defining instruction may have
+  // no register source at all (e.g. G_IMPLICIT_DEF).
+  Register SrcReg;
+  if (I1->getOpcode() == TargetOpcode::G_MUL) {
+    // If the product has other users it stays alive anyway, so forming a dot
+    // product would only add work.
+    if (!MRI.hasOneNonDBGUse(MidReg))
+      return false;
+    MachineInstr *ExtMI1 =
+        getDefIgnoringCopies(I1->getOperand(1).getReg(), MRI);
+    MachineInstr *ExtMI2 =
+        getDefIgnoringCopies(I1->getOperand(2).getReg(), MRI);
+    if (!ExtMI1 || !ExtMI2 || !IsExt(ExtMI1->getOpcode()) ||
+        ExtMI1->getOpcode() != ExtMI2->getOpcode())
+      return false;
+    // The extends only have to agree on their destination type; the DOT
+    // instruction additionally needs both byte sources to have the same type.
+    SrcReg = ExtMI1->getOperand(1).getReg();
+    if (MRI.getType(SrcReg) != MRI.getType(ExtMI2->getOperand(1).getReg()))
+      return false;
+  } else if (IsExt(I1->getOpcode())) {
+    SrcReg = I1->getOperand(1).getReg();
+  } else {
+    return false;
+  }
+
+  LLT SrcTy = MRI.getType(SrcReg);
+  if (!SrcTy.isVector() || SrcTy.getScalarSizeInBits() != 8)
+    return false;
+  unsigned NumElts = SrcTy.getNumElements();
+  return NumElts == 8 || NumElts % 16 == 0;
+}
+
+// Rewrites a G_VECREDUCE_ADD of mul(ext(x), ext(y)) or ext(x) into
+// G_UDOT/G_SDOT instructions accumulating into zero vectors, followed by a
+// G_VECREDUCE_ADD of the dot-product results:
+//  - 8 lanes:           one <2 x s32> DOT, reduced directly;
+//  - 16 lanes:          one <4 x s32> DOT, reduced directly;
+//  - 16*k lanes, k > 1: the byte sources are unmerged into k <16 x s8> pieces,
+//                       each feeding a <4 x s32> DOT; the results are
+//                       concatenated and reduced.
+// Other multiples of 8 are split into <8 x s8> pieces feeding <2 x s32> DOTs,
+// so every G_UNMERGE_VALUES result has the same type.
+// For vecreduce_add(ext(x)) the second DOT input is a splat of 1.
+// Only the root is erased. The multiply and extends are left in place because
+// they may still be used (by each other, as in mul(ext(x), ext(x)), through
+// COPYs, or elsewhere); once unused they are trivially dead.
+void applyExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
+                            MachineIRBuilder &Builder,
+                            GISelChangeObserver &Observer) {
+  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+         "Expected a G_VECREDUCE_ADD instruction");
+  Builder.setInstrAndDebugLoc(MI);
+
+  MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
+  Register Ext1SrcReg, Ext2SrcReg;
+  unsigned ExtOpc;
+  if (I1->getOpcode() == TargetOpcode::G_MUL) {
+    MachineInstr *Ext1MI =
+        getDefIgnoringCopies(I1->getOperand(1).getReg(), MRI);
+    MachineInstr *Ext2MI =
+        getDefIgnoringCopies(I1->getOperand(2).getReg(), MRI);
+    Ext1SrcReg = Ext1MI->getOperand(1).getReg();
+    Ext2SrcReg = Ext2MI->getOperand(1).getReg();
+    ExtOpc = Ext1MI->getOpcode();
+  } else if (I1->getOpcode() == TargetOpcode::G_ZEXT ||
+             I1->getOpcode() == TargetOpcode::G_SEXT) {
+    // sum(ext(x)) == sum(ext(x) * 1): use a splat of 1 as the other input.
+    Ext1SrcReg = I1->getOperand(1).getReg();
+    Ext2SrcReg = Builder.buildConstant(MRI.getType(Ext1SrcReg), 1).getReg(0);
+    ExtOpc = I1->getOpcode();
+  } else
+    return;
+  unsigned DotOpcode =
+      ExtOpc == TargetOpcode::G_ZEXT ? AArch64::G_UDOT : AArch64::G_SDOT;
+
+  LLT SrcTy = MRI.getType(Ext1SrcReg);
+  unsigned NumElts = SrcTy.getNumElements();
+  // Prefer 16-lane pieces (<4 x s32> DOTs); otherwise use 8-lane pieces
+  // (<2 x s32> DOTs) so that all pieces share a single type.
+  unsigned PieceElts = NumElts % 16 == 0 ? 16 : 8;
+  if (NumElts % PieceElts != 0)
+    return;
+  unsigned NumPieces = NumElts / PieceElts;
+  LLT PieceTy = LLT::fixed_vector(PieceElts, 8);
+  // Each 32-bit DOT lane accumulates four byte products.
+  LLT DotTy = LLT::fixed_vector(PieceElts / 4, 32);
+  Register DstReg = MI.getOperand(0).getReg();
+
+  if (NumPieces == 1) {
+    Register Zeroes = Builder.buildConstant(DotTy, 0).getReg(0);
+    auto Dot = Builder.buildInstr(DotOpcode, {DotTy},
+                                  {Zeroes, Ext1SrcReg, Ext2SrcReg});
+    Builder.buildVecReduceAdd(DstReg, Dot.getReg(0));
+  } else {
+    auto Ext1Unmerge = Builder.buildUnmerge(PieceTy, Ext1SrcReg);
+    auto Ext2Unmerge = Builder.buildUnmerge(PieceTy, Ext2SrcReg);
+
+    SmallVector<Register, 4> DotRegs;
+    for (unsigned I = 0; I != NumPieces; ++I) {
+      Register Zeroes = Builder.buildConstant(DotTy, 0).getReg(0);
+      auto Dot = Builder.buildInstr(
+          DotOpcode, {DotTy},
+          {Zeroes, Ext1Unmerge.getReg(I), Ext2Unmerge.getReg(I)});
+      DotRegs.push_back(Dot.getReg(0));
+    }
+
+    // Merge the partial sums and reduce them in one go.
+    LLT ConcatTy = LLT::fixed_vector(NumPieces * DotTy.getNumElements(), 32);
+    auto Concat = Builder.buildConcatVectors(ConcatTy, DotRegs);
+    Builder.buildVecReduceAdd(DstReg, Concat.getReg(0));
+  }
+
+  MI.eraseFromParent();
+}
+
 bool tryToSimplifyUADDO(MachineInstr &MI, MachineIRBuilder &B,
                         CombinerHelper &Helper, GISelChangeObserver &Observer) {
   // Try simplify G_UADDO with 8 or 16 bit operands to wide G_ADD and TBNZ if
diff --git a/llvm/test/CodeGen/AArch64/vecreduce-add.ll b/llvm/test/CodeGen/AArch64/vecreduce-add.ll
index a88c930d09e9b17..b4b221bf4e46461 100644
--- a/llvm/test/CodeGen/AArch64/vecreduce-add.ll
+++ b/llvm/test/CodeGen/AArch64/vecreduce-add.ll
@@ -440,14 +440,10 @@ define i32 @add_v16i8_v16i32_zext(<16 x i8> %x) {
 ;
 ; CHECK-GI-LABEL: add_v16i8_v16i32_zext:
 ; CHECK-GI:       // %bb.0: // %entry
-; CHECK-GI-NEXT:    ushll v1.8h, v0.8b, #0
-; CHECK-GI-NEXT:    ushll2 v0.8h, v0.16b, #0
-; CHECK-GI-NEXT:    ushll v2.4s, v1.4h, #0
-; CHECK-GI-NEXT:    ushll v3.4s, v0.4h, #0
-; CHECK-GI-NEXT:    uaddw2 v1.4s, v2.4s, v1.8h
-; CHECK-GI-NEXT:    uaddw2 v0.4s, v3.4s, v0.8h
-; CHECK-GI-NEXT:    add v0.4s, v1.4s, v0.4s
-; CHECK-GI-NEXT:    addv s0, v0.4s
+; CHECK-GI-NEXT:    movi v1.16b, #1
+; CHECK-GI-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT:    udot v2.4s, v0.16b, v1.16b
+; CHECK-GI-NEXT:    addv s0, v2.4s
 ; CHECK-GI-NEXT:    fmov w0, s0
 ; CHECK-GI-NEXT:    ret
 entry:
@@ -479,14 +475,10 @@ define i32 @add_v16i8_v16i32_sext(<16 x i8> %x) {
 ;
 ; CHECK-GI-LABEL: add_v16i8_v16i32_sext:
 ; CHECK-GI:       // %bb.0: // %entry
-; CHECK-GI-NEXT:    sshll v1.8h, v0.8b, #0
-; CHECK-GI-NEXT:    sshll2 v0.8h, v0.16b, #0
-; CHECK-GI-NEXT:    sshll v2.4s, v1.4h, #0
-; CHECK-GI-NEXT:    sshll v3.4s, v0.4h, #0
-; CHECK-GI-NEXT:    saddw2 v1.4s, v2.4s, v1.8h
-; CHECK-GI-NEXT:    saddw2 v0.4s, v3.4s, v0.8h
-; CHECK-GI-NEXT:    add v0.4s, v1.4s, v0.4s
-; CHECK-GI-NEXT:    addv s0, v0.4s
+; CHECK-GI-NEXT:    movi v1.16b, #1
+; CHECK-GI-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT:    sdot v2.4s, v0.16b, v1.16b
+; CHECK-GI-NEXT:    addv s0, v2.4s
 ; CHECK-GI-NEXT:    fmov w0, s0
 ; CHECK-GI-NEXT:    ret
 entry:
@@ -514,10 +506,10 @@ define i32 @add_v8i8_v8i32_zext(<8 x i8> %x) {
 ;
 ; CHECK-GI-LABEL: add_v8i8_v8i32_zext:
 ; CHECK-GI:       // %bb.0: // %entry
-; CHECK-GI-NEXT:    ushll v0.8h, v0.8b, #0
-; CHECK-GI-NEXT:    ushll v1.4s, v0.4h, #0
-; CHECK-GI-NEXT:    uaddw2 v0.4s, v1.4s, v0.8h
-; CHECK-GI-NEXT:    addv s0, v0.4s
+; CHECK-GI-NEXT:    movi v1.8b, #1
+; CHECK-GI-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT:    udot v2.2s, v0.8b, v1.8b
+; CHECK-GI-NEXT:    addp v0.2s, v2.2s, v2.2s
 ; CHECK-GI-NEXT:    fmov w0, s0
 ; CHECK-GI-NEXT:    ret
 entry:
@@ -545,10 +537,10 @@ define i32 @add_v8i8_v8i32_sext(<8 x i8> %x) {
 ;
 ; CHECK-GI-LABEL: add_v8i8_v8i32_sext:
 ; CHECK-GI:       // %bb.0: // %entry
-; CHECK-GI-NEXT:    sshll v0.8h, v0.8b, #0
-; CHECK-GI-NEXT:    sshll v1.4s, v0.4h, #0
-; CHECK-GI-NEXT:    saddw2 v0.4s, v1.4s, v0.8h
-; CHECK-GI-NEXT:    addv s0, v0.4s
+; CHECK-GI-NEXT:    movi v1.8b, #1
+; CHECK-GI-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT:    sdot v2.2s, v0.8b, v1.8b
+; CHECK-GI-NEXT:    addp v0.2s, v2.2s, v2.2s
 ; CHECK-GI-NEXT:    fmov w0, s0
 ; CHECK-GI-NEXT:    ret
 entry:
@@ -1560,14 +1552,10 @@ define i32 @add_v16i8_v16i32_acc_zext(<16 x i8> %x, i32 %a) {
 ;
 ; CHECK-GI-LABEL: add_v16i8_v16i32_acc_zext:
 ; CHECK-GI:       // %bb.0: // %entry
-; CHECK-GI-NEXT:    ushll v1.8h, v0.8b, #0
-; CHECK-GI-NEXT:    ushll2 v0.8h, v0.16b, #0
-; CHECK-GI-NEXT:    ushll v2.4s, v1.4h, #0
-; CHECK-GI-NEXT:    ushll v3.4s, v0.4h, #0
-; CHECK-GI-NEXT:    uaddw2 v1.4s, v2.4s, v1.8h
-; CHECK-GI-NEXT:    uaddw2 v0.4s, v3.4s, v0.8h
-; CHECK-GI-NEXT:    add v0.4s, v1.4s, v0.4s
-; CHECK-GI-NEXT:    addv s0, v0.4s
+; CHECK-GI-NEXT:    movi v1.16b, #1
+; CHECK-GI-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT:    udot v2.4s, v0.16b, v1.16b
+; CHECK-GI-NEXT:    addv s0, v2.4s
 ; CHECK-GI-NEXT:    fmov w8, s0
 ; CHECK-GI-NEXT:    add w0, w8, w0
 ; CHECK-GI-NEXT:    ret
@@ -1603,14 +1591,10 @@ define i32 @add_v16i8_v16i32_acc_sext(<16 x i8> %x, i32 %a) {
 ;
 ; CHECK-GI-LABEL: add_v16i8_v16i32_acc_sext:
 ; CHECK-GI:       // %bb.0: // %entry
-; CHECK-GI-NEXT:    sshll v1.8h, v0.8b, #0
-; CHECK-GI-NEXT:    sshll2 v0.8h, v0.16b, #0
-; CHECK-GI-NEXT:    sshll v2.4s, v1.4h, #0
-; CHECK-GI-NEXT:    sshll v3.4s, v0.4h, #0
-; CHECK-GI-NEXT:    saddw2 v1.4s, v2.4s, v1.8h
-; CHECK-GI-NEXT:    saddw2 v0.4s, v3.4s, v0.8h
-; CHECK-GI-NEXT:    add v0.4s, v1.4s, v0.4s
-; CHECK-GI-NEXT:    addv s0, v0.4s
+; CHECK-GI-NEXT:    movi v1.16b, #1
+; CHECK-GI-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT:    sdot v2.4s, v0.16b, v1.16b
+; CHECK-GI-NEXT:    addv s0, v2.4s
 ; CHECK-GI-NEXT:    fmov w8, s0
 ; CHECK-GI-NEXT:    add w0, w8, w0
 ; CHECK-GI-NEXT:    ret
@@ -1642,10 +1626,10 @@ define i32 @add_v8i8_v8i32_acc_zext(<8 x i8> %x, i32 %a) {
 ;
 ; CHECK-GI-LABEL: add_v8i8_v8i32_acc_zext:
 ; CHECK-GI:       // %bb.0: // %entry
-; CHECK-GI-NEXT:    ushll v0.8h, v0.8b, #0
-; CHECK-GI-NEXT:    ushll v1.4s, v0.4h, #0
-; CHECK-GI-NEXT:    uaddw2 v0.4s, v1.4s, v0.8h
-; CHECK-GI-NEXT:    addv s0, v0.4s
+; CHECK-GI-NEXT:    movi v1.8b, #1
+; CHECK-GI-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT:    udot v2.2s, v0.8b, v1.8b
+; CHECK-GI-NEXT:    addp v0.2s, v2.2s, v2.2s
 ; CHECK-GI-NEXT:    fmov w8, s0
 ; CHECK-GI-NEXT:    add w0, w8, w0
 ; CHECK-GI-NEXT:    ret
@@ -1677,10 +1661,10 @@ define i32 @add_v8i8_v8i32_acc_sext(<8 x i8> %x, i32 %a) {
 ;
 ; CHECK-GI-LABEL: add_v8i8_v8i32_acc_sext:
 ; CHECK-GI:       // %bb.0: // %entry
-; CHECK-GI-NEXT:    sshll v0.8h, v0.8b, #0
-; CHECK-GI-NEXT:    sshll v1.4s, v0.4h, #0
-; CHECK-GI-NEXT:    saddw2 v0.4s, v1.4s, v0.8h
-; CHECK-GI-NEXT:    addv s0, v0.4s
+; CHECK-GI-NEXT:    movi v1.8b, #1
+; CHECK-GI-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT:    sdot v2.2s, v0.8b, v1.8b
+; CHECK-GI-NEXT:    addp v0.2s, v2.2s, v2.2s
 ; CHECK-GI-NEXT:    fmov w8, s0
 ; CHECK-GI-NEXT:    add w0, w8, w0
 ; CHECK-GI-NEXT:    ret
@@ -2618,6 +2602,152 @@ entry:
   ret i32 %z
 }
 
+define i32 @test_udot_v8i8(<8 x i8> %a, <8 x i8> %b) {
+; CHECK-BASE-LABEL: test_udot_v8i8:
+; CHECK-BASE:       // %bb.0: // %entry
+; CHECK-BASE-NEXT:    ushll v0.8h, v0.8b, #0
+; CHECK-BASE-NEXT:    ushll v1.8h, v1.8b, #0
+; CHECK-BASE-NEXT:    umull v2.4s, v1.4h, v0.4h
+; CHECK-BASE-NEXT:    umlal2 v2.4s, v1.8h, v0.8h
+; CHECK-BASE-NEXT:    addv s0, v2.4s
+; CHECK-BASE-NEXT:    fmov w0, s0
+; CHECK-BASE-NEXT:    ret
+;
+; CHECK-DOT-LABEL: test_udot_v8i8:
+; CHECK-DOT:       // %bb.0: // %entry
+; CHECK-DOT-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-DOT-NEXT:    udot v2.2s, v1.8b, v0.8b
+; CHECK-DOT-NEXT:    addp v0.2s, v2.2s, v2.2s
+; CHECK-DOT-NEXT:    fmov w0, s0
+; CHECK-DOT-NEXT:    ret
+;
+; CHECK-GI-LABEL: test_udot_v8i8:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT:    udot v2.2s, v1.8b, v0.8b
+; CHECK-GI-NEXT:    addp v0.2s, v2.2s, v2.2s
+; CHECK-GI-NEXT:    fmov w0, s0
+; CHECK-GI-NEXT:    ret
+entry:
+  ; Sum over all lanes of zext(a) * zext(b). The CHECK-DOT and CHECK-GI runs
+  ; emit a single 64-bit UDOT into a zeroed accumulator; CHECK-BASE uses
+  ; widening multiply-accumulates instead.
+  %0 = zext <8 x i8> %a to <8 x i32>
+  %1 = zext <8 x i8> %b to <8 x i32>
+  %2 = mul nuw nsw <8 x i32> %1, %0
+  %3 = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %2)
+  ret i32 %3
+}
+
+define i32 @test_udot_v16i8(<16 x i8> %a, <16 x i8> %b) {
+; CHECK-BASE-LABEL: test_udot_v16i8:
+; CHECK-BASE:       // %bb.0: // %entry
+; CHECK-BASE-NEXT:    ushll v2.8h, v0.8b, #0
+; CHECK-BASE-NEXT:    ushll v3.8h, v1.8b, #0
+; CHECK-BASE-NEXT:    ushll2 v0.8h, v0.16b, #0
+; CHECK-BASE-NEXT:    ushll2 v1.8h, v1.16b, #0
+; CHECK-BASE-NEXT:    umull v4.4s, v3.4h, v2.4h
+; CHECK-BASE-NEXT:    umull2 v2.4s, v3.8h, v2.8h
+; CHECK-BASE-NEXT:    umlal2 v2.4s, v1.8h, v0.8h
+; CHECK-BASE-NEXT:    umlal v4.4s, v1.4h, v0.4h
+; CHECK-BASE-NEXT:    add v0.4s, v4.4s, v2.4s
+; CHECK-BASE-NEXT:    addv s0, v0.4s
+; CHECK-BASE-NEXT:    fmov w0, s0
+; CHECK-BASE-NEXT:    ret
+;
+; CHECK-DOT-LABEL: test_udot_v16i8:
+; CHECK-DOT:       // %bb.0: // %entry
+; CHECK-DOT-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-DOT-NEXT:    udot v2.4s, v1.16b, v0.16b
+; CHECK-DOT-NEXT:    addv s0, v2.4s
+; CHECK-DOT-NEXT:    fmov w0, s0
+; CHECK-DOT-NEXT:    ret
+;
+; CHECK-GI-LABEL: test_udot_v16i8:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT:    udot v2.4s, v1.16b, v0.16b
+; CHECK-GI-NEXT:    addv s0, v2.4s
+; CHECK-GI-NEXT:    fmov w0, s0
+; CHECK-GI-NEXT:    ret
+entry:
+  ; Sum over all lanes of zext(a) * zext(b). The CHECK-DOT and CHECK-GI runs
+  ; emit a single 128-bit UDOT followed by ADDV; CHECK-BASE uses widening
+  ; multiply-accumulates instead.
+  %0 = zext <16 x i8> %a to <16 x i32>
+  %1 = zext <16 x i8> %b to <16 x i32>
+  %2 = mul nuw nsw <16 x i32> %1, %0
+  %3 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %2)
+  ret i32 %3
+}
+
+define i32 @test_sdot_v8i8(<8 x i8> %a, <8 x i8> %b) {
+; CHECK-BASE-LABEL: test_sdot_v8i8:
+; CHECK-BASE:       // %bb.0: // %entry
+; CHECK-BASE-NEXT:    sshll v0.8h, v0.8b, #0
+; CHECK-BASE-NEXT:    sshll v1.8h, v1.8b, #0
+; CHECK-BASE-NEXT:    smull v2.4s, v1.4h, v0.4h
+; CHECK-BASE-NEXT:    smlal2 v2.4s, v1.8h, v0.8h
+; CHECK-BASE-NEXT:    addv s0, v2.4s
+; CHECK-BASE-NEXT:    fmov w0, s0
+; CHECK-BASE-NEXT:    ret
+;
+; CHECK-DOT-LABEL: test_sdot_v8i8:
+; CHECK-DOT:       // %bb.0: // %entry
+; CHECK-DOT-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-DOT-NEXT:    sdot v2.2s, v1.8b, v0.8b
+; CHECK-DOT-NEXT:    addp v0.2s, v2.2s, v2.2s
+; CHECK-DOT-NEXT:    fmov w0, s0
+; CHECK-DOT-NEXT:    ret
+;
+; CHECK-GI-LABEL: test_sdot_v8i8:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT:    sdot v2.2s, v1.8b, v0.8b
+; CHECK-GI-NEXT:    addp v0.2s, v2.2s, v2.2s
+; CHECK-GI-NEXT:    fmov w0, s0
+; CHECK-GI-NEXT:    ret
+entry:
+  ; Sum over all lanes of sext(a) * sext(b). The operands can be negative
+  ; (huge when viewed as unsigned), so e.g. (-1) * (-1) overflows unsigned:
+  ; the multiply may only carry nsw, since nuw would make it poison.
+  %0 = sext <8 x i8> %a to <8 x i32>
+  %1 = sext <8 x i8> %b to <8 x i32>
+  %2 = mul nsw <8 x i32> %1, %0
+  %3 = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %2)
+  ret i32 %3
+}
+
+define i32 @test_sdot_v16i8(<16 x i8> %a, <16 x i8> %b) {
+; CHECK-BASE-LABEL: test_sdot_v16i8:
+; CHECK-BASE:       // %bb.0: // %entry
+; CHECK-BASE-NEXT:    sshll v2.8h, v0.8b, #0
+; CHECK-BASE-NEXT:    sshll v3.8h, v1.8b, #0
+; CHECK-BASE-NEXT:    sshll2 v0.8h, v0.16b, #0
+; CHECK-BASE-NEXT:    sshll2 v1.8h, v1.16b, #0
+; CHECK-BASE-NEXT:    smull v4.4s, v3.4h, v2.4h
+; CHECK-BASE-NEXT:    smull2 v2.4s, v3.8h, v2.8h
+; CHECK-BASE-NEXT:    smlal2 v2.4s, v1.8h, v0.8h
+; CHECK-BASE-NEXT:    smlal v4.4s, v1.4h, v0.4h
+; CHECK-BASE-NEXT:    add v0.4s, v4.4s, v2.4s
+; CHECK-BASE-NEXT:    addv s0, v0.4s
+; CHECK-BASE-NEXT:    fmov w0, s0
+; CHECK-BASE-NEXT:    ret
+;
+; CHECK-DOT-LABEL: test_sdot_v16i8:
+; CHECK-DOT:       // %bb.0: // %entry
+; CHECK-DOT-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-DOT-NEXT:    sdot v2.4s, v1.16b, v0.16b
+; CHECK-DOT-NEXT:    addv s0, v2.4s
+; CHECK-DOT-NEXT:    fmov w0, s0
+; CHECK-DOT-NEXT:    ret
+;
+; CHECK-GI-LABEL: test_sdot_v16i8:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-GI-NEXT:    sdot v2.4s, v1.16b, v0.16b
+; CHECK-GI-NEXT:    addv s0, v2.4s
+; CHECK-GI-NEXT:    fmov w0, s0
+; CHECK-GI-NEXT:    ret
+entry:
+  ; Sum over all lanes of sext(a) * sext(b). The operands can be negative
+  ; (huge when viewed as unsigned), so e.g. (-1) * (-1) overflows unsigned:
+  ; the multiply may only carry nsw, since nuw would make it poison.
+  %0 = sext <16 x i8> %a to <16 x i32>
+  %1 = sext <16 x i8> %b to <16 x i32>
+  %2 = mul nsw <16 x i32> %1, %0
+  %3 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %2)
+  ret i32 %3
+}
+
 define zeroext i16 @add_pair_v8i16_v8i16(<8 x i16> %x, <8 x i16> %y) {
 ; CHECK-BASE-LABEL: add_pair_v8i16_v8i16:
 ; CHECK-BASE:       // %bb.0: // %entry
@@ -2990,22 +3120,13 @@ define i32 @add_pair_v16i8_v16i32_zext(<16 x i8> %x, <16 x i8> %y) {
 ;
 ; CHECK-GI-LABEL: add_pair_v16i8_v16i32_zext:
 ; CHECK-GI:       // %bb.0: // %entry
-; CHECK-GI-NEXT:    ushll v2.8h, v0.8b, #0
-; CHECK-GI-NEXT:    ushll2 v0.8h, v0.16b, #0
-; CHECK-GI-NEXT:    ushll v3.8h, v1.8b, #0
-; CHECK-GI-NEXT:    ushll2 v1.8h, v1.16b, #0
-; CHECK-GI-NEXT:    ushll v4.4s, v2.4h, #0
-; CHECK-GI-NEXT:    ushll v5.4s, v0.4h, #0
-; CHECK-GI-NEXT:    ushll v6.4s, v3.4h, #0
-; CHECK-GI-NEXT:    ushll v7.4s, v1.4h, #0
-; CHECK-GI-NEXT:    uaddw2 v2.4s, v4.4s, v2.8h
-; CHECK-GI-NEXT:    uaddw2 v0.4s, v5.4s, v0.8h
-; CHECK-GI-NEXT:    uaddw2 v3.4s, v6.4s, v3.8h
-; CHECK-GI-NEXT:    uaddw2 v1.4s, v7.4s, v1.8h
-; CHECK-GI-NEXT:    add v0.4s, v2.4s, v0.4s
-; CHECK-GI-NEXT:    add v1.4s, v3.4s, v1.4s
-; CHECK-GI-NEXT:    addv s0, v0.4s
-; CHECK-GI-NEXT:    addv s1, v1.4s
+; CHECK-GI-NEXT:    movi v2.16b, #1
+; CHECK-GI-NEXT:    movi v3.2d, #0000000000000000
+; CHECK-GI-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-GI-NEXT:    udot v4.4s, v0.16b, v2.16b
+; CHECK-GI-NEXT:    udot v3.4s, v1.16b, v2.16b
+; CHECK-GI-NEXT:    addv s0, v4.4s
+; CHECK-GI-NEXT:    addv s1, v3.4s
 ; CHECK-GI-NEXT:    fmov w8, s0
 ; CHECK-GI-NEXT:    fmov w9, s1
 ; CHECK-GI-NEXT:    add w0, w8, w9
@@ -3049,22 +3170,13 @@ define i32 @add_pair_v16i8_v16i32_sext(<16 x i8> %x, <16 x i8> %y) {
 ;
 ; CHECK-GI-LABEL: add_pair_v16i8_v16i32_sext:
 ; CHECK-GI:       // %bb.0: // %entry
-; CHECK-GI-NEXT:    sshll v2.8h, v0.8b, #0
-; CHECK-GI-NEXT:    sshll2 v0.8h, v0.16b, #0
-; CHECK-GI-NEXT:    sshll v3.8h, v1.8b, #0
-; CHECK-GI-NEXT:    sshll2 v1.8h, v1.16b, #0
-; CHECK-GI-NEXT:    sshll v4.4s, v2.4h, #0
-; CHECK-GI-NEXT:    sshll v5.4s, v0.4h, #0
-; CHECK-GI-NEXT:    sshll v6.4s, v3.4h, #0
-; CHECK-GI-NEXT:    sshll v7.4s, v1.4h, #0
-; CHECK-GI-NEXT:    saddw2 v2.4s, v4.4s, v2.8h
-; CHECK-GI-NEXT:    saddw2 v0.4s, v5.4s, v0.8h
-; CHECK-GI-NEXT:    saddw2 v3.4s, v6.4s, v3.8h
-; CHECK-GI-NEXT:    saddw2 v1.4s, v7.4s, v1.8h
-; CHECK-GI-NEXT:    add v0.4s, v2.4s, v0.4s
-; CHECK-GI-NEXT:    add v1.4s, v3.4s, v1.4s
-; CHECK-GI-NEXT:    addv s0, v0.4s
-; CHECK-GI-NEXT:    addv s1, v1.4s
+; CHECK-GI-NEXT:    movi v2.16b, #1
+; CHECK-GI-NEXT:    movi v3.2d, #0000000000000000
+; CHECK-GI-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-GI-NEXT:    sdot v4.4s, v0.16b, v2.16b
+; CHECK-GI-NEXT:    sdot v3.4s, v1.16b, v2.16b
+; CHECK-GI-NEXT:    addv s0, v4.4s
+; CHECK-GI-NEXT:    addv s1, v3.4s
 ; CHECK-GI-NEXT:    fmov w8, s0
 ; CHECK-GI-NEXT:    fmov w9, s1
 ; CHECK-GI-NEXT:    add w0, w8, w9
@@ -3101,14 +3213,13 @@ define i32 @add_pair_v8i8_v8i32_zext(<8 x i8> %x, <8 x i8> %y) {
 ;
 ; CHECK-GI-LABEL: add_pair_v8i8_v8i32_zext:
 ; CHECK-GI:       // %bb.0: // %entry
-; CHECK-GI-NEXT:    ushll v0.8h, v0.8b, #0
-; CHECK-GI-NEXT:    ushll v1.8h, v1.8b, #0
-; CHECK-GI-NEXT:    ushll v2.4s, v0.4h, #0
-; CHECK-GI-NEXT:    ushll v3.4s, v1.4h, #0
-; CHECK-GI-NEXT:    uaddw2 v0.4s, v2.4s, v0.8h
-; CHECK-GI-NEXT:    uaddw2 v1.4s, v3.4s, v1.8h
-; CHECK-GI-NEXT:    addv s0, v0.4s
-; CHECK-GI-NEXT:    addv s1, v1.4s
+; CHECK-GI-NEXT:    movi v2.8b, #1
+; CHECK-GI-NEXT:    movi v3.2d, #0000000000000000
+; CHECK-GI-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-GI-NEXT:    udot v4.2s, v0.8b, v2.8b
+; CHECK-GI-NEXT:    udot v3.2s, v1.8b, v2.8b
+; CHECK-GI-NEXT:    addp v0.2s, v4.2s, v4.2s
+; CHECK-GI-NEXT:    addp v1.2s, v3.2s, v3.2s
 ; CHECK-GI-NEXT:    fmov w8, s0
 ; CHECK-GI-NEXT:    fmov w9, s1
 ; CHECK-GI-NEXT:    add w0, w8, w9
@@ -3145,14 +3256,13 @@ define i32 @add_pair_v8i8_v8i32_sext(<8 x i8> %x, <8 x i8> %y) {
 ;
 ; CHECK-GI-LABEL: add_pair_v8i8_v8i32_sext:
 ; CHECK-GI:       // %bb.0: // %entry
-; CHECK-GI-NEXT:    sshll v0.8h, v0.8b, #0
-; CHECK-GI-NEXT:    sshll v1.8h, v1.8b, #0
-; CHECK-GI-NEXT:    sshll v2.4s, v0.4h, #0
-; CHECK-GI-NEXT:    sshll v3.4s, v1.4h, #0
-; CHECK-GI-NEXT:    saddw2 v0.4s, v2.4s, v0.8h
-; CHECK-GI-NEXT:    saddw2 v1.4s, v3.4s, v1.8h
-; CHECK-GI-NEXT:    addv s0, v0.4s
-; CHECK-GI-NEXT:    addv s1, v1.4s
+; CHECK-GI-NEXT:    movi v2.8b, #1
+; CHECK-GI-NEXT:    movi v3.2d, #0000000000000000
+; CHECK-GI-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-GI-NEXT:    sdot v4.2s, v0.8b, v2.8b
+; CHECK-GI-NEXT:    sdot v3.2s, v1.8b, v2.8b
+; CHECK-GI-NEXT:    addp v0.2s, v4.2s, v4.2s
+; CHECK-GI-NEXT:    addp v1.2s, v3.2s, v3.2s
 ; CHECK-GI-NEXT:    fmov w8, s0
 ; CHECK-GI-NEXT:    fmov w9, s1
 ; CHECK-GI-NEXT:    add w0, w8, w9
@@ -4066,26 +4176,23 @@ define i32 @add_pair_v8i8_v8i32_double_sext_zext(<8 x i8> %ax, <8 x i8> %ay, <8
 ;
 ; CHECK-GI-LABEL: add_pair_v8i8_v8i32_double_sext_zext:
 ; CHECK-GI:       // %bb.0: // %entry
-; CHECK-GI-NEXT:    ushll v0.8h, v0.8b, #0
-; CHECK-GI-NEXT:    ushll v1.8h, v1.8b, #0
-; CHECK-GI-NEXT:    sshll v2.8h, v2.8b, #0
-; CHECK-GI-NEXT:    sshll v3.8h, v3.8b, #0
-; CHECK-GI-NEXT:    ushll v4.4s, v0.4h, #0
-; CHECK-GI-NEXT:    ushll v5.4s, v1.4h, #0
-; CHECK-GI-NEXT:    sshll v6.4s, v2.4h, #0
-; CHECK-GI-NEXT:    sshll v7.4s, v3.4h, #0
-; CHECK-GI-NEXT:    uaddw2 v0.4s, v4.4s, v0.8h
-; CHECK-GI-NEXT:    uaddw2 v1.4s, v5.4s, v1.8h
-; CHECK-GI-NEXT:    saddw2 v2.4s, v6.4s, v2.8h
-; CHECK-GI-NEXT:    saddw2 v3.4s, v7.4s, v3.8h
-; CHECK-GI-NEXT:    addv s0, v0.4s
-; CHECK-GI-NEXT:    addv s1, v1.4s
-; CHECK-GI-NEXT:    addv s2, v2.4s
-; CHECK-GI-NEXT:    addv s3, v3.4s
+; CHECK-GI-NEXT:    movi v4.8b, #1
+; CHECK-GI-NEXT:    movi v5.2d, #0000000000000000
+; CHECK-GI-NEXT:    movi v6.2d, #0000000000000000
+; CHECK-GI-NEXT:    movi v7.2d, #0000000000000000
+; CHECK-GI-NEXT:    movi v16.2d, #0000000000000000
+; CHECK-GI-NEXT:    udot v5.2s, v0.8b, v4.8b
+; CHECK-GI-NEXT:    sdot v6.2s, v3.8b, v4.8b
+; CHECK-GI-NEXT:    udot v7.2s, v1.8b, v4.8b
+; CHECK-GI-NEXT:    sdot v16.2s, v2.8b, v4.8b
+; CHECK-GI-NEXT:    addp v0.2s, v5.2s, v5.2s
+; CHECK-GI-NEXT:    addp v3.2s, v6.2s, v6.2s
+; CHECK-GI-NEXT:    addp v1.2s, v7.2s, v7.2s
+; CHECK-GI-NEXT:    addp v2.2s, v16.2s, v16.2s
 ; CHECK-GI-NEXT:    fmov w8, s0
+; CHECK-GI-NEXT:    fmov w11, s3
 ; CHECK-GI-NEXT:    fmov w9, s1
 ; CHECK-GI-NEXT:    fmov w10, s2
-; CHECK-GI-NEXT:    fmov w11, s3
 ; CHECK-GI-NEXT:    add w8, w8, w9
 ; CHECK-GI-NEXT:    add w9, w10, w11
 ; CHECK-GI-NEXT:    add w0, w8, w9



More information about the llvm-commits mailing list