[llvm] 21babe4 - [X86] Combine reduce(add (mul x, y)) to VNNI instruction.

via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 7 05:14:28 PST 2022


Author: Luo, Yuanke
Date: 2022-01-07T21:12:19+08:00
New Revision: 21babe4db326a4bbac2e317ad50e4f62643e4a1d

URL: https://github.com/llvm/llvm-project/commit/21babe4db326a4bbac2e317ad50e4f62643e4a1d
DIFF: https://github.com/llvm/llvm-project/commit/21babe4db326a4bbac2e317ad50e4f62643e4a1d.diff

LOG: [X86] Combine reduce(add (mul x, y)) to VNNI instruction.

For the C code below, we can use VNNI to combine the mul and add operations.
int usdot_prod_qi(unsigned char *restrict a, char *restrict b, int c,
                  int n) {
  int i;
  for (i = 0; i < 32; i++) {
    c += ((int)a[i] * (int)b[i]);
  }
  return c;
}
This patch does not support the combine across basic blocks.

Differential Revision: https://reviews.llvm.org/D116039

Added: 
    llvm/test/CodeGen/X86/dpbusd.ll
    llvm/test/CodeGen/X86/dpbusd_i4.ll

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/lib/Target/X86/X86PartialReduction.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index ab61a0a515981..e2351c765e981 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -41799,6 +41799,40 @@ static SDValue combineBitcast(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
+// Match (mul (zext a), (sext b)); on success Op0 holds the zext operand.
+static bool detectExtMul(SelectionDAG &DAG, const SDValue &Mul, SDValue &Op0,
+                         SDValue &Op1) {
+  Op0 = Mul.getOperand(0);
+  Op1 = Mul.getOperand(1);
+
+  // Op1 must be the signed operand, so move a SIGN_EXTEND out of Op0.
+  if (Op0.getOpcode() == ISD::SIGN_EXTEND)
+    std::swap(Op0, Op1);
+
+  if (Op0.getOpcode() != ISD::ZERO_EXTEND)
+    return false;
+
+  auto IsFreeTruncation = [](SDValue &Op) -> bool {
+    if ((Op.getOpcode() == ISD::ZERO_EXTEND ||
+         Op.getOpcode() == ISD::SIGN_EXTEND) &&
+        Op.getOperand(0).getScalarValueSizeInBits() <= 8)
+      return true;
+
+    // TODO: Support constant values.
+    return false;
+  };
+
+  // (dpbusd (zext a), (sext b)): Op0 must fit in 8 unsigned bits (checked via
+  // known bits) and Op1 in 8 signed bits (checked via sign bits). Both must
+  // also be extends from <= 8 bits so that truncating them back is free.
+  if ((IsFreeTruncation(Op0) &&
+       DAG.computeKnownBits(Op0).countMaxActiveBits() <= 8) &&
+      (IsFreeTruncation(Op1) && DAG.ComputeMaxSignificantBits(Op1) <= 8))
+    return true;
+
+  return false;
+}
+
 // Given a ABS node, detect the following pattern:
 // (ABS (SUB (ZERO_EXTEND a), (ZERO_EXTEND b))).
 // This is useful as it is the input into a SAD pattern.
@@ -41820,6 +41854,50 @@ static bool detectZextAbsDiff(const SDValue &Abs, SDValue &Op0, SDValue &Op1) {
   return true;
 }
 
+static SDValue createVPDPBUSD(SelectionDAG &DAG, SDValue LHS, SDValue RHS,
+                              unsigned &LogBias, const SDLoc &DL,
+                              const X86Subtarget &Subtarget) {
+  // Convert to vXi8: LHS via zext-or-trunc, RHS via sext-or-trunc.
+  MVT Vi8VT =
+      MVT::getVectorVT(MVT::i8, LHS.getValueType().getVectorElementCount());
+  LHS = DAG.getZExtOrTrunc(LHS, DL, Vi8VT);
+  RHS = DAG.getSExtOrTrunc(RHS, DL, Vi8VT);
+
+  // VPDPBUSD(<N x i32> C, <4N x i8> A, <4N x i8> B): for each i32 lane j,
+  // C[j] += A[4j]B[4j] + A[4j+1]B[4j+1] + A[4j+2]B[4j+2] + A[4j+3]B[4j+3].
+  // Each i32 lane thus already sums 4 adjacent i8 products, i.e. performs
+  // the first log2(4) = 2 stages of the reduction; report that via LogBias
+  // so the caller does not emit those stages again.
+  LogBias = 2;
+  // Use at least 128 bits; without VLX, AVX512-VNNI needs 512-bit registers.
+  unsigned RegSize = std::max(128u, (unsigned)Vi8VT.getSizeInBits());
+  if (Subtarget.hasVNNI() && !Subtarget.hasVLX())
+    RegSize = std::max(512u, RegSize);
+
+  // Pad both i8 vectors to RegSize bits with zero elements (not a per-element
+  // zext); the padded lanes contribute 0 to every dot product.
+  unsigned NumConcat = RegSize / Vi8VT.getSizeInBits();
+  SmallVector<SDValue, 16> Ops(NumConcat, DAG.getConstant(0, DL, Vi8VT));
+  Ops[0] = LHS;
+  MVT ExtendedVT = MVT::getVectorVT(MVT::i8, RegSize / 8);
+  SDValue DpOp0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops);
+  Ops[0] = RHS;
+  SDValue DpOp1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops);
+
+  // Build the dot product with a zero accumulator; SplitOpsAndApply splits
+  // it into 256/512-bit pieces for AVXVNNI/AVX512VNNI.
+  auto DpBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
+                       ArrayRef<SDValue> Ops) {
+    MVT VT = MVT::getVectorVT(MVT::i32, Ops[0].getValueSizeInBits() / 32);
+    return DAG.getNode(X86ISD::VPDPBUSD, DL, VT, Ops);
+  };
+  MVT DpVT = MVT::getVectorVT(MVT::i32, RegSize / 32);
+  SDValue Zero = DAG.getConstant(0, DL, DpVT);
+
+  return SplitOpsAndApply(DAG, Subtarget, DL, DpVT, {Zero, DpOp0, DpOp1},
+                          DpBuilder, false);
+}
+
 // Given two zexts of <k x i8> to <k x i32>, create a PSADBW of the inputs
 // to these zexts.
 static SDValue createPSADBW(SelectionDAG &DAG, const SDValue &Zext0,
@@ -42069,6 +42147,77 @@ static SDValue combinePredicateReduction(SDNode *Extract, SelectionDAG &DAG,
   return DAG.getNode(ISD::SUB, DL, ExtractVT, Zero, Zext);
 }
 
+static SDValue combineVPDPBUSDPattern(SDNode *Extract, SelectionDAG &DAG,
+                                      const X86Subtarget &Subtarget) {
+  if (!Subtarget.hasVNNI() && !Subtarget.hasAVXVNNI())
+    return SDValue();
+
+  EVT ExtractVT = Extract->getValueType(0);
+  // Verify the type we're extracting is i32, as the output element type of
+  // vpdpbusd is i32.
+  if (ExtractVT != MVT::i32)
+    return SDValue();
+
+  EVT VT = Extract->getOperand(0).getValueType();
+  if (!isPowerOf2_32(VT.getVectorNumElements()))
+    return SDValue();
+
+  // Match shuffle + add pyramid.
+  ISD::NodeType BinOp;
+  SDValue Root = DAG.matchBinOpReduction(Extract, BinOp, {ISD::ADD});
+
+  // We can't use vpdpbusd for a (zext (mul ...)) root, because each of the 4
+  // multiplies per lane computes a signed 16-bit product that is sign extended
+  // before being added into the accumulator.
+  // TODO:
+  // We also need to verify that the multiply has at least 2x the number of bits
+  // of the input. We shouldn't match
+  // (sign_extend (mul (vXi9 (zext (vXi8 X))), (vXi9 (zext (vXi8 Y)))).
+  // if (Root && (Root.getOpcode() == ISD::SIGN_EXTEND))
+  //   Root = Root.getOperand(0);
+
+  // If there was a match, we want Root to be a mul.
+  if (!Root || Root.getOpcode() != ISD::MUL)
+    return SDValue();
+
+  // Check whether we have an extend and mul pattern
+  SDValue LHS, RHS;
+  if (!detectExtMul(DAG, Root, LHS, RHS))
+    return SDValue();
+
+  // Create the dot product instruction.
+  SDLoc DL(Extract);
+  unsigned StageBias;
+  SDValue DP = createVPDPBUSD(DAG, LHS, RHS, StageBias, DL, Subtarget);
+
+  // If the original vector was wider than 4 elements, finish the reduction
+  // over the DP vector with log2(NumElts) - StageBias shuffle+add stages.
+  unsigned Stages = Log2_32(VT.getVectorNumElements());
+  EVT DpVT = DP.getValueType();
+
+  if (Stages > StageBias) {
+    unsigned DpElems = DpVT.getVectorNumElements();
+
+    for (unsigned i = Stages - StageBias; i > 0; --i) {
+      SmallVector<int, 16> Mask(DpElems, -1);
+      for (unsigned j = 0, MaskEnd = 1 << (i - 1); j < MaskEnd; ++j)
+        Mask[j] = MaskEnd + j;
+
+      SDValue Shuffle =
+          DAG.getVectorShuffle(DpVT, DL, DP, DAG.getUNDEF(DpVT), Mask);
+      DP = DAG.getNode(ISD::ADD, DL, DpVT, DP, Shuffle);
+    }
+  }
+
+  // Bitcast to ExtractVT lanes and extract; the total is in lane 0.
+  EVT ResVT =
+      EVT::getVectorVT(*DAG.getContext(), ExtractVT,
+                       DpVT.getSizeInBits() / ExtractVT.getSizeInBits());
+  DP = DAG.getBitcast(ResVT, DP);
+  return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ExtractVT, DP,
+                     Extract->getOperand(1));
+}
+
 static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
                                       const X86Subtarget &Subtarget) {
   // PSADBW is only supported on SSE2 and up.
@@ -42676,6 +42825,9 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
   if (SDValue SAD = combineBasicSADPattern(N, DAG, Subtarget))
     return SAD;
 
+  if (SDValue VPDPBUSD = combineVPDPBUSDPattern(N, DAG, Subtarget))
+    return VPDPBUSD;
+
   // Attempt to replace an all_of/any_of horizontal reduction with a MOVMSK.
   if (SDValue Cmp = combinePredicateReduction(N, DAG, Subtarget))
     return Cmp;

diff  --git a/llvm/lib/Target/X86/X86PartialReduction.cpp b/llvm/lib/Target/X86/X86PartialReduction.cpp
index babd923e7496a..4e1bb047f2243 100644
--- a/llvm/lib/Target/X86/X86PartialReduction.cpp
+++ b/llvm/lib/Target/X86/X86PartialReduction.cpp
@@ -13,15 +13,16 @@
 //===----------------------------------------------------------------------===//
 
 #include "X86.h"
+#include "X86TargetMachine.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/CodeGen/TargetPassConfig.h"
 #include "llvm/IR/Constants.h"
+#include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicsX86.h"
-#include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/Pass.h"
-#include "X86TargetMachine.h"
+#include "llvm/Support/KnownBits.h"
 
 using namespace llvm;
 
@@ -49,7 +50,7 @@ class X86PartialReduction : public FunctionPass {
   }
 
 private:
-  bool tryMAddReplacement(Instruction *Op);
+  bool tryMAddReplacement(Instruction *Op, bool ReduceInOneBB);
   bool trySADReplacement(Instruction *Op);
 };
 }
@@ -63,7 +64,46 @@ char X86PartialReduction::ID = 0;
 INITIALIZE_PASS(X86PartialReduction, DEBUG_TYPE,
                 "X86 Partial Reduction", false, false)
 
-bool X86PartialReduction::tryMAddReplacement(Instruction *Op) {
+// IR counterpart of detectExtMul() in X86ISelLowering.cpp; keep them in sync.
+static bool matchVPDPBUSDPattern(const X86Subtarget *ST, BinaryOperator *Mul,
+                                 const DataLayout *DL) {
+  if (!ST->hasVNNI() && !ST->hasAVXVNNI())
+    return false;
+
+  Value *LHS = Mul->getOperand(0);
+  Value *RHS = Mul->getOperand(1);
+
+  if (isa<SExtInst>(LHS))
+    std::swap(LHS, RHS);
+
+  if (!isa<ZExtInst>(LHS))
+    return false;
+
+  auto IsFreeTruncation = [&](Value *Op) {
+    if (auto *Cast = dyn_cast<CastInst>(Op)) {
+      if (Cast->getParent() == Mul->getParent() &&
+          (Cast->getOpcode() == Instruction::SExt ||
+           Cast->getOpcode() == Instruction::ZExt) &&
+          Cast->getOperand(0)->getType()->getScalarSizeInBits() <= 8)
+        return true;
+    }
+    // TODO: Support constant operands (ISel's detectExtMul() lacks it too).
+    return false;
+  };
+
+  // (dpbusd (zext a), (sext b)): LHS must fit in 8 unsigned bits (checked via
+  // known bits) and RHS in 8 signed bits (checked via sign bits). Both must be
+  // extends from <= 8 bits in Mul's block so that truncating them is free.
+  if ((IsFreeTruncation(LHS) &&
+       computeKnownBits(LHS, *DL).countMaxActiveBits() <= 8) &&
+      (IsFreeTruncation(RHS) && ComputeMaxSignificantBits(RHS, *DL) <= 8))
+    return true;
+
+  return false;
+}
+
+bool X86PartialReduction::tryMAddReplacement(Instruction *Op,
+                                             bool ReduceInOneBB) {
   if (!ST->hasSSE2())
     return false;
 
@@ -82,6 +122,13 @@ bool X86PartialReduction::tryMAddReplacement(Instruction *Op) {
   Value *LHS = Mul->getOperand(0);
   Value *RHS = Mul->getOperand(1);
 
+  // If the target support VNNI, leave it to ISel to combine reduce operation
+  // to VNNI instruction.
+  // TODO: we can support transforming reduce to VNNI intrinsic for across block
+  // in this pass.
+  if (ReduceInOneBB && matchVPDPBUSDPattern(ST, Mul, DL))
+    return false;
+
   // LHS and RHS should be only used once or if they are the same then only
   // used twice. Only check this when SSE4.1 is enabled and we have zext/sext
   // instructions, otherwise we use punpck to emulate zero extend in stages. The
@@ -300,7 +347,9 @@ bool X86PartialReduction::trySADReplacement(Instruction *Op) {
 
 // Walk backwards from the ExtractElementInst and determine if it is the end of
 // a horizontal reduction. Return the input to the reduction if we find one.
-static Value *matchAddReduction(const ExtractElementInst &EE) {
+static Value *matchAddReduction(const ExtractElementInst &EE,
+                                bool &ReduceInOneBB) {
+  ReduceInOneBB = true;
   // Make sure we're extracting index 0.
   auto *Index = dyn_cast<ConstantInt>(EE.getIndexOperand());
   if (!Index || !Index->isNullValue())
@@ -309,6 +358,8 @@ static Value *matchAddReduction(const ExtractElementInst &EE) {
   const auto *BO = dyn_cast<BinaryOperator>(EE.getVectorOperand());
   if (!BO || BO->getOpcode() != Instruction::Add || !BO->hasOneUse())
     return nullptr;
+  if (EE.getParent() != BO->getParent())
+    ReduceInOneBB = false;
 
   unsigned NumElems = cast<FixedVectorType>(BO->getType())->getNumElements();
   // Ensure the reduction size is a power of 2.
@@ -321,6 +372,8 @@ static Value *matchAddReduction(const ExtractElementInst &EE) {
     const auto *BO = dyn_cast<BinaryOperator>(Op);
     if (!BO || BO->getOpcode() != Instruction::Add)
       return nullptr;
+    if (EE.getParent() != BO->getParent())
+      ReduceInOneBB = false;
 
     // If this isn't the first add, then it should only have 2 users, the
     // shuffle and another add which we checked in the previous iteration.
@@ -460,9 +513,10 @@ bool X86PartialReduction::runOnFunction(Function &F) {
       if (!EE)
         continue;
 
+      bool ReduceInOneBB;
       // First find a reduction tree.
       // FIXME: Do we need to handle other opcodes than Add?
-      Value *Root = matchAddReduction(*EE);
+      Value *Root = matchAddReduction(*EE, ReduceInOneBB);
       if (!Root)
         continue;
 
@@ -470,7 +524,7 @@ bool X86PartialReduction::runOnFunction(Function &F) {
       collectLeaves(Root, Leaves);
 
       for (Instruction *I : Leaves) {
-        if (tryMAddReplacement(I)) {
+        if (tryMAddReplacement(I, ReduceInOneBB)) {
           MadeChange = true;
           continue;
         }

diff  --git a/llvm/test/CodeGen/X86/dpbusd.ll b/llvm/test/CodeGen/X86/dpbusd.ll
new file mode 100644
index 0000000000000..534aa36215d64
--- /dev/null
+++ b/llvm/test/CodeGen/X86/dpbusd.ll
@@ -0,0 +1,548 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avxvnni | FileCheck %s --check-prefixes=AVXVNNI
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx512vnni | FileCheck %s --check-prefixes=AVX512,AVX512VNNI
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx512vnni -mattr=+avx512vl | FileCheck %s --check-prefixes=AVX512,AVX512VLVNNI
+
+define i32 @no_dpbusd(i8 *%a, i8 *%b, i32 %c, i32 %n) {
+; AVXVNNI-LABEL: no_dpbusd:
+; AVXVNNI:       # %bb.0: # %entry
+; AVXVNNI-NEXT:    vpmovzxbw {{.*#+}} ymm0 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero,mem[8],zero,mem[9],zero,mem[10],zero,mem[11],zero,mem[12],zero,mem[13],zero,mem[14],zero,mem[15],zero
+; AVXVNNI-NEXT:    vpmovzxbw {{.*#+}} ymm1 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero,mem[8],zero,mem[9],zero,mem[10],zero,mem[11],zero,mem[12],zero,mem[13],zero,mem[14],zero,mem[15],zero
+; AVXVNNI-NEXT:    vpmaddwd %ymm0, %ymm1, %ymm0
+; AVXVNNI-NEXT:    vextracti128 $1, %ymm0, %xmm1
+; AVXVNNI-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVXVNNI-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,2,3]
+; AVXVNNI-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVXVNNI-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1]
+; AVXVNNI-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVXVNNI-NEXT:    vmovd %xmm0, %eax
+; AVXVNNI-NEXT:    addl %edx, %eax
+; AVXVNNI-NEXT:    vzeroupper
+; AVXVNNI-NEXT:    retq
+;
+; AVX512-LABEL: no_dpbusd:
+; AVX512:       # %bb.0: # %entry
+; AVX512-NEXT:    vpmovzxbw {{.*#+}} ymm0 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero,mem[8],zero,mem[9],zero,mem[10],zero,mem[11],zero,mem[12],zero,mem[13],zero,mem[14],zero,mem[15],zero
+; AVX512-NEXT:    vpmovzxbw {{.*#+}} ymm1 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero,mem[8],zero,mem[9],zero,mem[10],zero,mem[11],zero,mem[12],zero,mem[13],zero,mem[14],zero,mem[15],zero
+; AVX512-NEXT:    vpmaddwd %ymm0, %ymm1, %ymm0
+; AVX512-NEXT:    vextracti128 $1, %ymm0, %xmm1
+; AVX512-NEXT:    vpaddd %ymm1, %ymm0, %ymm0
+; AVX512-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,2,3]
+; AVX512-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVX512-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1]
+; AVX512-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVX512-NEXT:    vmovd %xmm0, %eax
+; AVX512-NEXT:    addl %edx, %eax
+; AVX512-NEXT:    vzeroupper
+; AVX512-NEXT:    retq
+entry:
+  %0 = bitcast i8* %a to <16 x i8>*
+  %1 = load <16 x i8>, <16 x i8>* %0, align 16
+  %2 = zext <16 x i8> %1 to <16 x i32>
+  %3 = bitcast i8* %b to <16 x i8>*
+  %4 = load <16 x i8>, <16 x i8>* %3, align 16
+  %5 = zext <16 x i8> %4 to <16 x i32>
+  %6 = mul nsw <16 x i32> %5, %2
+  %7 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %6)
+  %op.extra = add nsw i32 %7, %c
+  ret i32 %op.extra
+}
+
+define i32 @vpdpbusd_mutate(i8 *%a, i8 *%b, i32 %c, i32 %n) {
+; AVXVNNI-LABEL: vpdpbusd_mutate:
+; AVXVNNI:       # %bb.0: # %entry
+; AVXVNNI-NEXT:    vmovdqa (%rsi), %xmm0
+; AVXVNNI-NEXT:    vpxor %xmm1, %xmm1, %xmm1
+; AVXVNNI-NEXT:    {vex} vpdpbusd (%rdi), %xmm0, %xmm1
+; AVXVNNI-NEXT:    vpshufd {{.*#+}} xmm0 = xmm1[2,3,2,3]
+; AVXVNNI-NEXT:    vpaddd %xmm0, %xmm1, %xmm0
+; AVXVNNI-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1]
+; AVXVNNI-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVXVNNI-NEXT:    vmovd %xmm0, %eax
+; AVXVNNI-NEXT:    addl %edx, %eax
+; AVXVNNI-NEXT:    retq
+;
+; AVX512VNNI-LABEL: vpdpbusd_mutate:
+; AVX512VNNI:       # %bb.0: # %entry
+; AVX512VNNI-NEXT:    vmovdqa (%rdi), %xmm0
+; AVX512VNNI-NEXT:    vmovdqa (%rsi), %xmm1
+; AVX512VNNI-NEXT:    vpxor %xmm2, %xmm2, %xmm2
+; AVX512VNNI-NEXT:    vpdpbusd %zmm0, %zmm1, %zmm2
+; AVX512VNNI-NEXT:    vpshufd {{.*#+}} xmm0 = xmm2[2,3,2,3]
+; AVX512VNNI-NEXT:    vpaddd %xmm0, %xmm2, %xmm0
+; AVX512VNNI-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1]
+; AVX512VNNI-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVX512VNNI-NEXT:    vmovd %xmm0, %eax
+; AVX512VNNI-NEXT:    addl %edx, %eax
+; AVX512VNNI-NEXT:    vzeroupper
+; AVX512VNNI-NEXT:    retq
+;
+; AVX512VLVNNI-LABEL: vpdpbusd_mutate:
+; AVX512VLVNNI:       # %bb.0: # %entry
+; AVX512VLVNNI-NEXT:    vmovdqa (%rsi), %xmm0
+; AVX512VLVNNI-NEXT:    vpxor %xmm1, %xmm1, %xmm1
+; AVX512VLVNNI-NEXT:    vpdpbusd (%rdi), %xmm0, %xmm1
+; AVX512VLVNNI-NEXT:    vpshufd {{.*#+}} xmm0 = xmm1[2,3,2,3]
+; AVX512VLVNNI-NEXT:    vpaddd %xmm0, %xmm1, %xmm0
+; AVX512VLVNNI-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1]
+; AVX512VLVNNI-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVX512VLVNNI-NEXT:    vmovd %xmm0, %eax
+; AVX512VLVNNI-NEXT:    addl %edx, %eax
+; AVX512VLVNNI-NEXT:    retq
+entry:
+  %0 = bitcast i8* %a to <16 x i8>*
+  %1 = load <16 x i8>, <16 x i8>* %0, align 16
+  %2 = sext <16 x i8> %1 to <16 x i32>
+  %3 = bitcast i8* %b to <16 x i8>*
+  %4 = load <16 x i8>, <16 x i8>* %3, align 16
+  %5 = zext <16 x i8> %4 to <16 x i32>
+  %6 = mul nsw <16 x i32> %5, %2
+  %7 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %6)
+  %op.extra = add nsw i32 %7, %c
+  ret i32 %op.extra
+}
+
+define i32 @mul_zext(i8 *%a, i8 *%b, i32 %c, i32 %n) {
+; AVXVNNI-LABEL: mul_zext:
+; AVXVNNI:       # %bb.0: # %entry
+; AVXVNNI-NEXT:    vpmovzxbw {{.*#+}} ymm0 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero,mem[8],zero,mem[9],zero,mem[10],zero,mem[11],zero,mem[12],zero,mem[13],zero,mem[14],zero,mem[15],zero
+; AVXVNNI-NEXT:    vpmovsxbw (%rsi), %ymm1
+; AVXVNNI-NEXT:    vpmullw %ymm0, %ymm1, %ymm0
+; AVXVNNI-NEXT:    vextracti128 $1, %ymm0, %xmm1
+; AVXVNNI-NEXT:    vpmovzxwd {{.*#+}} ymm1 = xmm1[0],zero,xmm1[1],zero,xmm1[2],zero,xmm1[3],zero,xmm1[4],zero,xmm1[5],zero,xmm1[6],zero,xmm1[7],zero
+; AVXVNNI-NEXT:    vpmovzxwd {{.*#+}} ymm0 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero,xmm0[4],zero,xmm0[5],zero,xmm0[6],zero,xmm0[7],zero
+; AVXVNNI-NEXT:    vpaddd %ymm1, %ymm0, %ymm0
+; AVXVNNI-NEXT:    vextracti128 $1, %ymm0, %xmm1
+; AVXVNNI-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVXVNNI-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,2,3]
+; AVXVNNI-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVXVNNI-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1]
+; AVXVNNI-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVXVNNI-NEXT:    vmovd %xmm0, %eax
+; AVXVNNI-NEXT:    addl %edx, %eax
+; AVXVNNI-NEXT:    vzeroupper
+; AVXVNNI-NEXT:    retq
+;
+; AVX512-LABEL: mul_zext:
+; AVX512:       # %bb.0: # %entry
+; AVX512-NEXT:    vpmovzxbw {{.*#+}} ymm0 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero,mem[8],zero,mem[9],zero,mem[10],zero,mem[11],zero,mem[12],zero,mem[13],zero,mem[14],zero,mem[15],zero
+; AVX512-NEXT:    vpmovsxbw (%rsi), %ymm1
+; AVX512-NEXT:    vpmullw %ymm0, %ymm1, %ymm0
+; AVX512-NEXT:    vpmovzxwd {{.*#+}} zmm0 = ymm0[0],zero,ymm0[1],zero,ymm0[2],zero,ymm0[3],zero,ymm0[4],zero,ymm0[5],zero,ymm0[6],zero,ymm0[7],zero,ymm0[8],zero,ymm0[9],zero,ymm0[10],zero,ymm0[11],zero,ymm0[12],zero,ymm0[13],zero,ymm0[14],zero,ymm0[15],zero
+; AVX512-NEXT:    vextracti64x4 $1, %zmm0, %ymm1
+; AVX512-NEXT:    vpaddd %zmm1, %zmm0, %zmm0
+; AVX512-NEXT:    vextracti128 $1, %ymm0, %xmm1
+; AVX512-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVX512-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,2,3]
+; AVX512-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVX512-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1]
+; AVX512-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVX512-NEXT:    vmovd %xmm0, %eax
+; AVX512-NEXT:    addl %edx, %eax
+; AVX512-NEXT:    vzeroupper
+; AVX512-NEXT:    retq
+entry:
+  %0 = bitcast i8* %a to <16 x i8>*
+  %1 = load <16 x i8>, <16 x i8>* %0, align 16
+  %2 = zext <16 x i8> %1 to <16 x i16>
+  %3 = bitcast i8* %b to <16 x i8>*
+  %4 = load <16 x i8>, <16 x i8>* %3, align 16
+  %5 = sext <16 x i8> %4 to <16 x i16>
+  %6 = mul nsw <16 x i16> %5, %2
+  ; We can't combine to vpdpbusd for this zext, because each of the 4
+  ; multiplies done by vpdpbusd computes a signed 16-bit product that is sign
+  ; extended before being added into the accumulator.
+  %7 = zext <16 x i16> %6 to <16 x i32>
+  %8 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %7)
+  %op.extra = add nsw i32 %8, %c
+  ret i32 %op.extra
+}
+
+define i32 @mul_sext(i8 *%a, i8 *%b, i32 %c, i32 %n) {
+; AVXVNNI-LABEL: mul_sext:
+; AVXVNNI:       # %bb.0: # %entry
+; AVXVNNI-NEXT:    vpmovzxbw {{.*#+}} ymm0 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero,mem[8],zero,mem[9],zero,mem[10],zero,mem[11],zero,mem[12],zero,mem[13],zero,mem[14],zero,mem[15],zero
+; AVXVNNI-NEXT:    vpmovsxbw (%rsi), %ymm1
+; AVXVNNI-NEXT:    vpmullw %ymm0, %ymm1, %ymm0
+; AVXVNNI-NEXT:    vextracti128 $1, %ymm0, %xmm1
+; AVXVNNI-NEXT:    vpmovsxwd %xmm1, %ymm1
+; AVXVNNI-NEXT:    vpmovsxwd %xmm0, %ymm0
+; AVXVNNI-NEXT:    vpaddd %ymm1, %ymm0, %ymm0
+; AVXVNNI-NEXT:    vextracti128 $1, %ymm0, %xmm1
+; AVXVNNI-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVXVNNI-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,2,3]
+; AVXVNNI-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVXVNNI-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1]
+; AVXVNNI-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVXVNNI-NEXT:    vmovd %xmm0, %eax
+; AVXVNNI-NEXT:    addl %edx, %eax
+; AVXVNNI-NEXT:    vzeroupper
+; AVXVNNI-NEXT:    retq
+;
+; AVX512-LABEL: mul_sext:
+; AVX512:       # %bb.0: # %entry
+; AVX512-NEXT:    vpmovzxbw {{.*#+}} ymm0 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero,mem[8],zero,mem[9],zero,mem[10],zero,mem[11],zero,mem[12],zero,mem[13],zero,mem[14],zero,mem[15],zero
+; AVX512-NEXT:    vpmovsxbw (%rsi), %ymm1
+; AVX512-NEXT:    vpmullw %ymm0, %ymm1, %ymm0
+; AVX512-NEXT:    vpmovsxwd %ymm0, %zmm0
+; AVX512-NEXT:    vextracti64x4 $1, %zmm0, %ymm1
+; AVX512-NEXT:    vpaddd %zmm1, %zmm0, %zmm0
+; AVX512-NEXT:    vextracti128 $1, %ymm0, %xmm1
+; AVX512-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVX512-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,2,3]
+; AVX512-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVX512-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1]
+; AVX512-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVX512-NEXT:    vmovd %xmm0, %eax
+; AVX512-NEXT:    addl %edx, %eax
+; AVX512-NEXT:    vzeroupper
+; AVX512-NEXT:    retq
+entry:
+  %0 = bitcast i8* %a to <16 x i8>*
+  %1 = load <16 x i8>, <16 x i8>* %0, align 16
+  %2 = zext <16 x i8> %1 to <16 x i16>
+  %3 = bitcast i8* %b to <16 x i8>*
+  %4 = load <16 x i8>, <16 x i8>* %3, align 16
+  %5 = sext <16 x i8> %4 to <16 x i16>
+  %6 = mul nsw <16 x i16> %5, %2
+  ; TODO:
+  ; We also need to verify that the multiply has at least 2x the number of bits
+  ; of the input. We shouldn't match
+  ; (sign_extend (mul (vXi9 (zext (vXi8 X))), (vXi9 (zext (vXi8 Y)))).
+  %7 = sext <16 x i16> %6 to <16 x i32>
+  %8 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %7)
+  %op.extra = add nsw i32 %8, %c
+  ret i32 %op.extra
+}
+
+define i32 @vpdpbusd_512(i8 *%a, i8 *%b, i32 %c, i32 %n) {
+; AVXVNNI-LABEL: vpdpbusd_512:
+; AVXVNNI:       # %bb.0: # %entry
+; AVXVNNI-NEXT:    vmovdqa (%rdi), %xmm0
+; AVXVNNI-NEXT:    vpxor %xmm1, %xmm1, %xmm1
+; AVXVNNI-NEXT:    {vex} vpdpbusd (%rsi), %xmm0, %xmm1
+; AVXVNNI-NEXT:    vpshufd {{.*#+}} xmm0 = xmm1[2,3,2,3]
+; AVXVNNI-NEXT:    vpaddd %xmm0, %xmm1, %xmm0
+; AVXVNNI-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1]
+; AVXVNNI-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVXVNNI-NEXT:    vmovd %xmm0, %eax
+; AVXVNNI-NEXT:    addl %edx, %eax
+; AVXVNNI-NEXT:    retq
+;
+; AVX512VNNI-LABEL: vpdpbusd_512:
+; AVX512VNNI:       # %bb.0: # %entry
+; AVX512VNNI-NEXT:    vmovdqa (%rdi), %xmm0
+; AVX512VNNI-NEXT:    vmovdqa (%rsi), %xmm1
+; AVX512VNNI-NEXT:    vpxor %xmm2, %xmm2, %xmm2
+; AVX512VNNI-NEXT:    vpdpbusd %zmm1, %zmm0, %zmm2
+; AVX512VNNI-NEXT:    vpshufd {{.*#+}} xmm0 = xmm2[2,3,2,3]
+; AVX512VNNI-NEXT:    vpaddd %xmm0, %xmm2, %xmm0
+; AVX512VNNI-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1]
+; AVX512VNNI-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVX512VNNI-NEXT:    vmovd %xmm0, %eax
+; AVX512VNNI-NEXT:    addl %edx, %eax
+; AVX512VNNI-NEXT:    vzeroupper
+; AVX512VNNI-NEXT:    retq
+;
+; AVX512VLVNNI-LABEL: vpdpbusd_512:
+; AVX512VLVNNI:       # %bb.0: # %entry
+; AVX512VLVNNI-NEXT:    vmovdqa (%rdi), %xmm0
+; AVX512VLVNNI-NEXT:    vpxor %xmm1, %xmm1, %xmm1
+; AVX512VLVNNI-NEXT:    vpdpbusd (%rsi), %xmm0, %xmm1
+; AVX512VLVNNI-NEXT:    vpshufd {{.*#+}} xmm0 = xmm1[2,3,2,3]
+; AVX512VLVNNI-NEXT:    vpaddd %xmm0, %xmm1, %xmm0
+; AVX512VLVNNI-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1]
+; AVX512VLVNNI-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVX512VLVNNI-NEXT:    vmovd %xmm0, %eax
+; AVX512VLVNNI-NEXT:    addl %edx, %eax
+; AVX512VLVNNI-NEXT:    retq
+entry:
+  %0 = bitcast i8* %a to <16 x i8>*
+  %1 = load <16 x i8>, <16 x i8>* %0, align 16
+  %2 = zext <16 x i8> %1 to <16 x i32>
+  %3 = bitcast i8* %b to <16 x i8>*
+  %4 = load <16 x i8>, <16 x i8>* %3, align 16
+  %5 = sext <16 x i8> %4 to <16 x i32>
+  %6 = mul nsw <16 x i32> %5, %2
+  %7 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %6)
+  %op.extra = add nsw i32 %7, %c
+  ret i32 %op.extra
+}
+
+declare i32 @llvm.vector.reduce.add.v16i32(<16 x i32>)
+
+define i32 @vpdpbusd_256(i8 *%a, i8 *%b, i32 %c, i32 %n) {
+; AVXVNNI-LABEL: vpdpbusd_256:
+; AVXVNNI:       # %bb.0: # %entry
+; AVXVNNI-NEXT:    vmovq {{.*#+}} xmm0 = mem[0],zero
+; AVXVNNI-NEXT:    vmovq {{.*#+}} xmm1 = mem[0],zero
+; AVXVNNI-NEXT:    vpxor %xmm2, %xmm2, %xmm2
+; AVXVNNI-NEXT:    {vex} vpdpbusd %xmm0, %xmm1, %xmm2
+; AVXVNNI-NEXT:    vpshufd {{.*#+}} xmm0 = xmm2[1,1,1,1]
+; AVXVNNI-NEXT:    vpaddd %xmm0, %xmm2, %xmm0
+; AVXVNNI-NEXT:    vmovd %xmm0, %eax
+; AVXVNNI-NEXT:    addl %edx, %eax
+; AVXVNNI-NEXT:    retq
+;
+; AVX512VNNI-LABEL: vpdpbusd_256:
+; AVX512VNNI:       # %bb.0: # %entry
+; AVX512VNNI-NEXT:    vmovq {{.*#+}} xmm0 = mem[0],zero
+; AVX512VNNI-NEXT:    vmovq {{.*#+}} xmm1 = mem[0],zero
+; AVX512VNNI-NEXT:    vpxor %xmm2, %xmm2, %xmm2
+; AVX512VNNI-NEXT:    vpdpbusd %zmm0, %zmm1, %zmm2
+; AVX512VNNI-NEXT:    vpshufd {{.*#+}} xmm0 = xmm2[1,1,1,1]
+; AVX512VNNI-NEXT:    vpaddd %xmm0, %xmm2, %xmm0
+; AVX512VNNI-NEXT:    vmovd %xmm0, %eax
+; AVX512VNNI-NEXT:    addl %edx, %eax
+; AVX512VNNI-NEXT:    vzeroupper
+; AVX512VNNI-NEXT:    retq
+;
+; AVX512VLVNNI-LABEL: vpdpbusd_256:
+; AVX512VLVNNI:       # %bb.0: # %entry
+; AVX512VLVNNI-NEXT:    vmovq {{.*#+}} xmm0 = mem[0],zero
+; AVX512VLVNNI-NEXT:    vmovq {{.*#+}} xmm1 = mem[0],zero
+; AVX512VLVNNI-NEXT:    vpxor %xmm2, %xmm2, %xmm2
+; AVX512VLVNNI-NEXT:    vpdpbusd %xmm0, %xmm1, %xmm2
+; AVX512VLVNNI-NEXT:    vpshufd {{.*#+}} xmm0 = xmm2[1,1,1,1]
+; AVX512VLVNNI-NEXT:    vpaddd %xmm0, %xmm2, %xmm0
+; AVX512VLVNNI-NEXT:    vmovd %xmm0, %eax
+; AVX512VLVNNI-NEXT:    addl %edx, %eax
+; AVX512VLVNNI-NEXT:    retq
+entry:
+  %0 = bitcast i8* %a to <8 x i8>*
+  %1 = load <8 x i8>, <8 x i8>* %0, align 8
+  %2 = zext <8 x i8> %1 to <8 x i32>
+  %3 = bitcast i8* %b to <8 x i8>*
+  %4 = load <8 x i8>, <8 x i8>* %3, align 8
+  %5 = sext <8 x i8> %4 to <8 x i32>
+  %6 = mul nsw <8 x i32> %5, %2
+  %7 = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %6)
+  %op.extra = add nsw i32 %7, %c
+  ret i32 %op.extra
+}
+
+declare i32 @llvm.vector.reduce.add.v8i32(<8 x i32>)
+
+define i32 @vpdpbusd_128(i8 *%a, i8 *%b, i32 %c, i32 %n) {
+; AVXVNNI-LABEL: vpdpbusd_128:
+; AVXVNNI:       # %bb.0: # %entry
+; AVXVNNI-NEXT:    vmovq {{.*#+}} xmm0 = mem[0],zero
+; AVXVNNI-NEXT:    vmovq {{.*#+}} xmm1 = mem[0],zero
+; AVXVNNI-NEXT:    vpxor %xmm2, %xmm2, %xmm2
+; AVXVNNI-NEXT:    vpblendw {{.*#+}} xmm1 = xmm1[0,1],xmm2[2,3,4,5,6,7]
+; AVXVNNI-NEXT:    vpblendw {{.*#+}} xmm0 = xmm0[0,1],xmm2[2,3,4,5,6,7]
+; AVXVNNI-NEXT:    {vex} vpdpbusd %xmm1, %xmm0, %xmm2
+; AVXVNNI-NEXT:    vmovd %xmm2, %eax
+; AVXVNNI-NEXT:    addl %edx, %eax
+; AVXVNNI-NEXT:    retq
+;
+; AVX512VNNI-LABEL: vpdpbusd_128:
+; AVX512VNNI:       # %bb.0: # %entry
+; AVX512VNNI-NEXT:    vmovq {{.*#+}} xmm0 = mem[0],zero
+; AVX512VNNI-NEXT:    vpxor %xmm1, %xmm1, %xmm1
+; AVX512VNNI-NEXT:    vpblendw {{.*#+}} xmm0 = xmm0[0,1],xmm1[2,3,4,5,6,7]
+; AVX512VNNI-NEXT:    vmovq {{.*#+}} xmm2 = mem[0],zero
+; AVX512VNNI-NEXT:    vpblendw {{.*#+}} xmm1 = xmm2[0,1],xmm1[2,3,4,5,6,7]
+; AVX512VNNI-NEXT:    vpxor %xmm2, %xmm2, %xmm2
+; AVX512VNNI-NEXT:    vpdpbusd %zmm0, %zmm1, %zmm2
+; AVX512VNNI-NEXT:    vmovd %xmm2, %eax
+; AVX512VNNI-NEXT:    addl %edx, %eax
+; AVX512VNNI-NEXT:    vzeroupper
+; AVX512VNNI-NEXT:    retq
+;
+; AVX512VLVNNI-LABEL: vpdpbusd_128:
+; AVX512VLVNNI:       # %bb.0: # %entry
+; AVX512VLVNNI-NEXT:    vmovq {{.*#+}} xmm0 = mem[0],zero
+; AVX512VLVNNI-NEXT:    vmovq {{.*#+}} xmm1 = mem[0],zero
+; AVX512VLVNNI-NEXT:    vpxor %xmm2, %xmm2, %xmm2
+; AVX512VLVNNI-NEXT:    vpblendw {{.*#+}} xmm1 = xmm1[0,1],xmm2[2,3,4,5,6,7]
+; AVX512VLVNNI-NEXT:    vpblendw {{.*#+}} xmm0 = xmm0[0,1],xmm2[2,3,4,5,6,7]
+; AVX512VLVNNI-NEXT:    vpxor %xmm2, %xmm2, %xmm2
+; AVX512VLVNNI-NEXT:    vpdpbusd %xmm1, %xmm0, %xmm2
+; AVX512VLVNNI-NEXT:    vmovd %xmm2, %eax
+; AVX512VLVNNI-NEXT:    addl %edx, %eax
+; AVX512VLVNNI-NEXT:    retq
+entry:
+  %0 = bitcast i8* %a to <4 x i8>*
+  %1 = load <4 x i8>, <4 x i8>* %0, align 8
+  %2 = zext <4 x i8> %1 to <4 x i32>
+  %3 = bitcast i8* %b to <4 x i8>*
+  %4 = load <4 x i8>, <4 x i8>* %3, align 8
+  %5 = sext <4 x i8> %4 to <4 x i32>
+  %6 = mul nsw <4 x i32> %5, %2
+  %7 = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %6)
+  %op.extra = add nsw i32 %7, %c
+  ret i32 %op.extra
+}
+
+declare i32 @llvm.vector.reduce.add.v4i32(<4 x i32>)
+
+define i32 @vpdpbusd_2xi32(i8 *%a, i8 *%b, i32 %c, i32 %n) {
+; AVXVNNI-LABEL: vpdpbusd_2xi32:
+; AVXVNNI:       # %bb.0: # %entry
+; AVXVNNI-NEXT:    vmovq {{.*#+}} xmm0 = mem[0],zero
+; AVXVNNI-NEXT:    vmovq {{.*#+}} xmm1 = mem[0],zero
+; AVXVNNI-NEXT:    vpxor %xmm2, %xmm2, %xmm2
+; AVXVNNI-NEXT:    vpblendw {{.*#+}} xmm1 = xmm1[0],xmm2[1,2,3,4,5,6,7]
+; AVXVNNI-NEXT:    vpblendw {{.*#+}} xmm0 = xmm0[0],xmm2[1,2,3,4,5,6,7]
+; AVXVNNI-NEXT:    {vex} vpdpbusd %xmm1, %xmm0, %xmm2
+; AVXVNNI-NEXT:    vmovd %xmm2, %eax
+; AVXVNNI-NEXT:    addl %edx, %eax
+; AVXVNNI-NEXT:    retq
+;
+; AVX512VNNI-LABEL: vpdpbusd_2xi32:
+; AVX512VNNI:       # %bb.0: # %entry
+; AVX512VNNI-NEXT:    vmovq {{.*#+}} xmm0 = mem[0],zero
+; AVX512VNNI-NEXT:    vmovdqa {{.*#+}} xmm1 = [65535,0,0,0]
+; AVX512VNNI-NEXT:    vpandq %zmm1, %zmm0, %zmm0
+; AVX512VNNI-NEXT:    vmovq {{.*#+}} xmm2 = mem[0],zero
+; AVX512VNNI-NEXT:    vpandq %zmm1, %zmm2, %zmm1
+; AVX512VNNI-NEXT:    vpxor %xmm2, %xmm2, %xmm2
+; AVX512VNNI-NEXT:    vpdpbusd %zmm0, %zmm1, %zmm2
+; AVX512VNNI-NEXT:    vmovd %xmm2, %eax
+; AVX512VNNI-NEXT:    addl %edx, %eax
+; AVX512VNNI-NEXT:    vzeroupper
+; AVX512VNNI-NEXT:    retq
+;
+; AVX512VLVNNI-LABEL: vpdpbusd_2xi32:
+; AVX512VLVNNI:       # %bb.0: # %entry
+; AVX512VLVNNI-NEXT:    vmovq {{.*#+}} xmm0 = mem[0],zero
+; AVX512VLVNNI-NEXT:    vmovq {{.*#+}} xmm1 = mem[0],zero
+; AVX512VLVNNI-NEXT:    vpxor %xmm2, %xmm2, %xmm2
+; AVX512VLVNNI-NEXT:    vpblendw {{.*#+}} xmm1 = xmm1[0],xmm2[1,2,3,4,5,6,7]
+; AVX512VLVNNI-NEXT:    vpblendw {{.*#+}} xmm0 = xmm0[0],xmm2[1,2,3,4,5,6,7]
+; AVX512VLVNNI-NEXT:    vpdpbusd %xmm1, %xmm0, %xmm2
+; AVX512VLVNNI-NEXT:    vmovd %xmm2, %eax
+; AVX512VLVNNI-NEXT:    addl %edx, %eax
+; AVX512VLVNNI-NEXT:    retq
+entry:  ; sum of zext(a[i]) * sext(b[i]) over 2 x i8, plus %c (%n is unused)
+  %a.vptr = bitcast i8* %a to <2 x i8>*
+  %a.bytes = load <2 x i8>, <2 x i8>* %a.vptr, align 8
+  %a.wide = zext <2 x i8> %a.bytes to <2 x i32>
+  %b.vptr = bitcast i8* %b to <2 x i8>*
+  %b.bytes = load <2 x i8>, <2 x i8>* %b.vptr, align 8
+  %b.wide = sext <2 x i8> %b.bytes to <2 x i32>
+  %prod = mul nsw <2 x i32> %b.wide, %a.wide
+  %dot = call i32 @llvm.vector.reduce.add.v2i32(<2 x i32> %prod)
+  %total = add nsw i32 %dot, %c
+  ret i32 %total
+}
+
+declare i32 @llvm.vector.reduce.add.v2i32(<2 x i32>)
+
+define i32 @vpdpbusd_32xi32(i8 *%a, i8 *%b, i32 %c, i32 %n) {
+; AVXVNNI-LABEL: vpdpbusd_32xi32:
+; AVXVNNI:       # %bb.0: # %entry
+; AVXVNNI-NEXT:    vmovdqu (%rdi), %ymm0
+; AVXVNNI-NEXT:    vpxor %xmm1, %xmm1, %xmm1
+; AVXVNNI-NEXT:    {vex} vpdpbusd (%rsi), %ymm0, %ymm1
+; AVXVNNI-NEXT:    vextracti128 $1, %ymm1, %xmm0
+; AVXVNNI-NEXT:    vpaddd %xmm0, %xmm1, %xmm0
+; AVXVNNI-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,2,3]
+; AVXVNNI-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVXVNNI-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1]
+; AVXVNNI-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVXVNNI-NEXT:    vmovd %xmm0, %eax
+; AVXVNNI-NEXT:    addl %edx, %eax
+; AVXVNNI-NEXT:    vzeroupper
+; AVXVNNI-NEXT:    retq
+;
+; AVX512VNNI-LABEL: vpdpbusd_32xi32:
+; AVX512VNNI:       # %bb.0: # %entry
+; AVX512VNNI-NEXT:    vmovdqu (%rdi), %ymm0
+; AVX512VNNI-NEXT:    vmovdqu (%rsi), %ymm1
+; AVX512VNNI-NEXT:    vpxor %xmm2, %xmm2, %xmm2
+; AVX512VNNI-NEXT:    vpdpbusd %zmm1, %zmm0, %zmm2
+; AVX512VNNI-NEXT:    vextracti128 $1, %ymm2, %xmm0
+; AVX512VNNI-NEXT:    vpaddd %xmm0, %xmm2, %xmm0
+; AVX512VNNI-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,2,3]
+; AVX512VNNI-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVX512VNNI-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1]
+; AVX512VNNI-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVX512VNNI-NEXT:    vmovd %xmm0, %eax
+; AVX512VNNI-NEXT:    addl %edx, %eax
+; AVX512VNNI-NEXT:    vzeroupper
+; AVX512VNNI-NEXT:    retq
+;
+; AVX512VLVNNI-LABEL: vpdpbusd_32xi32:
+; AVX512VLVNNI:       # %bb.0: # %entry
+; AVX512VLVNNI-NEXT:    vmovdqu (%rdi), %ymm0
+; AVX512VLVNNI-NEXT:    vpxor %xmm1, %xmm1, %xmm1
+; AVX512VLVNNI-NEXT:    vpdpbusd (%rsi), %ymm0, %ymm1
+; AVX512VLVNNI-NEXT:    vextracti128 $1, %ymm1, %xmm0
+; AVX512VLVNNI-NEXT:    vpaddd %xmm0, %xmm1, %xmm0
+; AVX512VLVNNI-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,2,3]
+; AVX512VLVNNI-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVX512VLVNNI-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1]
+; AVX512VLVNNI-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVX512VLVNNI-NEXT:    vmovd %xmm0, %eax
+; AVX512VLVNNI-NEXT:    addl %edx, %eax
+; AVX512VLVNNI-NEXT:    vzeroupper
+; AVX512VLVNNI-NEXT:    retq
+entry:  ; sum of zext(a[i]) * sext(b[i]) over 32 x i8, plus %c (%n is unused)
+  %a.vptr = bitcast i8* %a to <32 x i8>*
+  %a.bytes = load <32 x i8>, <32 x i8>* %a.vptr, align 16
+  %a.wide = zext <32 x i8> %a.bytes to <32 x i32>
+  %b.vptr = bitcast i8* %b to <32 x i8>*
+  %b.bytes = load <32 x i8>, <32 x i8>* %b.vptr, align 16
+  %b.wide = sext <32 x i8> %b.bytes to <32 x i32>
+  %prod = mul nsw <32 x i32> %b.wide, %a.wide
+  %dot = call i32 @llvm.vector.reduce.add.v32i32(<32 x i32> %prod)
+  %total = add nsw i32 %dot, %c
+  ret i32 %total
+}
+
+declare i32 @llvm.vector.reduce.add.v32i32(<32 x i32>)
+
+define i32 @vpdpbusd_64xi32(i8 *%a, i8 *%b, i32 %c, i32 %n) {
+; AVXVNNI-LABEL: vpdpbusd_64xi32:
+; AVXVNNI:       # %bb.0: # %entry
+; AVXVNNI-NEXT:    vmovdqu (%rdi), %ymm0
+; AVXVNNI-NEXT:    vmovdqu 32(%rdi), %ymm1
+; AVXVNNI-NEXT:    vpxor %xmm2, %xmm2, %xmm2
+; AVXVNNI-NEXT:    vpxor %xmm3, %xmm3, %xmm3
+; AVXVNNI-NEXT:    {vex} vpdpbusd 32(%rsi), %ymm1, %ymm3
+; AVXVNNI-NEXT:    {vex} vpdpbusd (%rsi), %ymm0, %ymm2
+; AVXVNNI-NEXT:    vpaddd %ymm3, %ymm2, %ymm0
+; AVXVNNI-NEXT:    vextracti128 $1, %ymm0, %xmm1
+; AVXVNNI-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVXVNNI-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,2,3]
+; AVXVNNI-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVXVNNI-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1]
+; AVXVNNI-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVXVNNI-NEXT:    vmovd %xmm0, %eax
+; AVXVNNI-NEXT:    addl %edx, %eax
+; AVXVNNI-NEXT:    vzeroupper
+; AVXVNNI-NEXT:    retq
+;
+; AVX512-LABEL: vpdpbusd_64xi32:
+; AVX512:       # %bb.0: # %entry
+; AVX512-NEXT:    vmovdqu64 (%rdi), %zmm0
+; AVX512-NEXT:    vpxor %xmm1, %xmm1, %xmm1
+; AVX512-NEXT:    vpdpbusd (%rsi), %zmm0, %zmm1
+; AVX512-NEXT:    vextracti64x4 $1, %zmm1, %ymm0
+; AVX512-NEXT:    vpaddd %zmm0, %zmm1, %zmm0
+; AVX512-NEXT:    vextracti128 $1, %ymm0, %xmm1
+; AVX512-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVX512-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,2,3]
+; AVX512-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVX512-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1]
+; AVX512-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVX512-NEXT:    vmovd %xmm0, %eax
+; AVX512-NEXT:    addl %edx, %eax
+; AVX512-NEXT:    vzeroupper
+; AVX512-NEXT:    retq
+entry:  ; sum of zext(a[i]) * sext(b[i]) over 64 x i8, plus %c (%n is unused)
+  %a.vptr = bitcast i8* %a to <64 x i8>*
+  %a.bytes = load <64 x i8>, <64 x i8>* %a.vptr, align 16
+  %a.wide = zext <64 x i8> %a.bytes to <64 x i32>
+  %b.vptr = bitcast i8* %b to <64 x i8>*
+  %b.bytes = load <64 x i8>, <64 x i8>* %b.vptr, align 16
+  %b.wide = sext <64 x i8> %b.bytes to <64 x i32>
+  %prod = mul nsw <64 x i32> %b.wide, %a.wide
+  %dot = call i32 @llvm.vector.reduce.add.v64i32(<64 x i32> %prod)
+  %total = add nsw i32 %dot, %c
+  ret i32 %total
+}
+
+declare i32 @llvm.vector.reduce.add.v64i32(<64 x i32>)

diff  --git a/llvm/test/CodeGen/X86/dpbusd_i4.ll b/llvm/test/CodeGen/X86/dpbusd_i4.ll
new file mode 100644
index 0000000000000..1f259211dada3
--- /dev/null
+++ b/llvm/test/CodeGen/X86/dpbusd_i4.ll
@@ -0,0 +1,131 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx512vnni -mattr=+avx512vl | FileCheck %s
+
+declare i32 @llvm.vector.reduce.add.v16i32(<16 x i32>)
+
+define i32 @mul_i8i8(i8 *%a, <16 x i8> %b, i32 %c) {
+; CHECK-LABEL: mul_i8i8:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vmovdqa (%rdi), %xmm1
+; CHECK-NEXT:    vpxor %xmm2, %xmm2, %xmm2
+; CHECK-NEXT:    vpdpbusd %xmm0, %xmm1, %xmm2
+; CHECK-NEXT:    vpshufd {{.*#+}} xmm0 = xmm2[2,3,2,3]
+; CHECK-NEXT:    vpaddd %xmm0, %xmm2, %xmm0
+; CHECK-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1]
+; CHECK-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; CHECK-NEXT:    vmovd %xmm0, %eax
+; CHECK-NEXT:    addl %esi, %eax
+; CHECK-NEXT:    retq
+entry:  ; zext(loaded a) * sext(%b), 16 x i8, reduced and added to %c
+  %a.vptr = bitcast i8* %a to <16 x i8>*
+  %a.bytes = load <16 x i8>, <16 x i8>* %a.vptr, align 16
+  %a.wide = zext <16 x i8> %a.bytes to <16 x i32>
+  %b.wide = sext <16 x i8> %b to <16 x i32>
+  %prod = mul nsw <16 x i32> %a.wide, %b.wide
+  %dot = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %prod)
+  %total = add nsw i32 %dot, %c
+  ret i32 %total
+}
+
+define i32 @mul_i4i8(<16 x i4> %a, <16 x i8> %b, i32 %c) {
+; CHECK-LABEL: mul_i4i8:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; CHECK-NEXT:    vpxor %xmm2, %xmm2, %xmm2
+; CHECK-NEXT:    vpdpbusd %xmm1, %xmm0, %xmm2
+; CHECK-NEXT:    vpshufd {{.*#+}} xmm0 = xmm2[2,3,2,3]
+; CHECK-NEXT:    vpaddd %xmm0, %xmm2, %xmm0
+; CHECK-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1]
+; CHECK-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; CHECK-NEXT:    vmovd %xmm0, %eax
+; CHECK-NEXT:    addl %edi, %eax
+; CHECK-NEXT:    retq
+entry:  ; zext(i4 a) * sext(i8 b), 16 lanes, reduced and added to %c
+  %a.wide = zext <16 x i4> %a to <16 x i32>
+  %b.wide = sext <16 x i8> %b to <16 x i32>
+  %prod = mul nsw <16 x i32> %a.wide, %b.wide
+  %dot = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %prod)
+  %total = add nsw i32 %dot, %c
+  ret i32 %total
+}
+
+define i32 @mul_i4i4(<16 x i4> %a, <16 x i4> %b, i32 %c) {
+; CHECK-LABEL: mul_i4i4:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vpsllw $4, %xmm1, %xmm1
+; CHECK-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
+; CHECK-NEXT:    vpsrlw $4, %xmm1, %xmm1
+; CHECK-NEXT:    vmovdqa {{.*#+}} xmm2 = [8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8]
+; CHECK-NEXT:    vpxor %xmm2, %xmm1, %xmm1
+; CHECK-NEXT:    vpsubb %xmm2, %xmm1, %xmm1
+; CHECK-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; CHECK-NEXT:    vpxor %xmm2, %xmm2, %xmm2
+; CHECK-NEXT:    vpdpbusd %xmm1, %xmm0, %xmm2
+; CHECK-NEXT:    vpshufd {{.*#+}} xmm0 = xmm2[2,3,2,3]
+; CHECK-NEXT:    vpaddd %xmm0, %xmm2, %xmm0
+; CHECK-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1]
+; CHECK-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; CHECK-NEXT:    vmovd %xmm0, %eax
+; CHECK-NEXT:    addl %edi, %eax
+; CHECK-NEXT:    retq
+entry:  ; zext(i4 a) * sext(i4 b), 16 lanes, reduced and added to %c
+  %a.wide = zext <16 x i4> %a to <16 x i32>
+  %b.wide = sext <16 x i4> %b to <16 x i32>
+  %prod = mul nsw <16 x i32> %a.wide, %b.wide
+  %dot = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %prod)
+  %total = add nsw i32 %dot, %c
+  ret i32 %total
+}
+
+define i32 @mul_sext_i4i4(<16 x i4> %a, <16 x i4> %b, i32 %c) {
+; CHECK-LABEL: mul_sext_i4i4:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vpmovzxbw {{.*#+}} ymm0 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero,xmm0[4],zero,xmm0[5],zero,xmm0[6],zero,xmm0[7],zero,xmm0[8],zero,xmm0[9],zero,xmm0[10],zero,xmm0[11],zero,xmm0[12],zero,xmm0[13],zero,xmm0[14],zero,xmm0[15],zero
+; CHECK-NEXT:    vpmovzxbw {{.*#+}} ymm1 = xmm1[0],zero,xmm1[1],zero,xmm1[2],zero,xmm1[3],zero,xmm1[4],zero,xmm1[5],zero,xmm1[6],zero,xmm1[7],zero,xmm1[8],zero,xmm1[9],zero,xmm1[10],zero,xmm1[11],zero,xmm1[12],zero,xmm1[13],zero,xmm1[14],zero,xmm1[15],zero
+; CHECK-NEXT:    vpsllw $12, %ymm1, %ymm1
+; CHECK-NEXT:    vpsraw $12, %ymm1, %ymm1
+; CHECK-NEXT:    vpsllw $12, %ymm0, %ymm0
+; CHECK-NEXT:    vpsraw $12, %ymm0, %ymm0
+; CHECK-NEXT:    vpmaddwd %ymm1, %ymm0, %ymm0
+; CHECK-NEXT:    vextracti128 $1, %ymm0, %xmm1
+; CHECK-NEXT:    vpaddd %ymm1, %ymm0, %ymm0
+; CHECK-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,2,3]
+; CHECK-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; CHECK-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1]
+; CHECK-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; CHECK-NEXT:    vmovd %xmm0, %eax
+; CHECK-NEXT:    addl %edi, %eax
+; CHECK-NEXT:    vzeroupper
+; CHECK-NEXT:    retq
+entry:  ; sext * sext: no zext operand, so CHECK expects vpmaddwd rather than vpdpbusd
+  %a.wide = sext <16 x i4> %a to <16 x i32>
+  %b.wide = sext <16 x i4> %b to <16 x i32>
+  %prod = mul nsw <16 x i32> %a.wide, %b.wide
+  %dot = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %prod)
+  %total = add nsw i32 %dot, %c
+  ret i32 %total
+}
+
+define i32 @mul_zext_i4i4(<16 x i4> %a, <16 x i4> %b, i32 %c) {
+; CHECK-LABEL: mul_zext_i4i4:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vmovdqa {{.*#+}} xmm2 = [15,15,15,15,15,15,15,15,15,15,15,15,15,15,15,15]
+; CHECK-NEXT:    vpand %xmm2, %xmm1, %xmm1
+; CHECK-NEXT:    vpand %xmm2, %xmm0, %xmm0
+; CHECK-NEXT:    vpxor %xmm2, %xmm2, %xmm2
+; CHECK-NEXT:    vpdpbusd %xmm1, %xmm0, %xmm2
+; CHECK-NEXT:    vpshufd {{.*#+}} xmm0 = xmm2[2,3,2,3]
+; CHECK-NEXT:    vpaddd %xmm0, %xmm2, %xmm0
+; CHECK-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1]
+; CHECK-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; CHECK-NEXT:    vmovd %xmm0, %eax
+; CHECK-NEXT:    addl %edi, %eax
+; CHECK-NEXT:    retq
+entry:  ; zext * zext of i4 lanes: CHECK still expects vpdpbusd
+  %a.wide = zext <16 x i4> %a to <16 x i32>
+  %b.wide = zext <16 x i4> %b to <16 x i32>
+  %prod = mul nsw <16 x i32> %a.wide, %b.wide
+  %dot = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %prod)
+  %total = add nsw i32 %dot, %c
+  ret i32 %total
+}


        


More information about the llvm-commits mailing list