[llvm] [AArch64] Generate DOT instructions from matching IR (PR #69583)

Graham Hunter via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 19 03:14:45 PDT 2023


https://github.com/huntergr-arm created https://github.com/llvm/llvm-project/pull/69583

This pass matches sequences of extend->mul->accumulate and replaces
them with DOT intrinsics.

This currently only supports SVE and scalable vectors.


>From 2b6da683e001ba852674d0f55cc5beb95c14782f Mon Sep 17 00:00:00 2001
From: Graham Hunter <graham.hunter at arm.com>
Date: Fri, 13 Oct 2023 14:09:55 +0100
Subject: [PATCH] [AArch64] Generate DOT instructions from matching IR

This pass matches sequences of extend->mul->accumulate and replaces
them with DOT intrinsics.

This currently only supports SVE and scalable vectors.
---
 llvm/lib/Target/AArch64/AArch64.h             |   2 +
 .../Target/AArch64/AArch64DotProdMatcher.cpp  | 486 +++++++++++++
 .../Target/AArch64/AArch64TargetMachine.cpp   |  11 +
 llvm/lib/Target/AArch64/CMakeLists.txt        |   1 +
 llvm/test/CodeGen/AArch64/O3-pipeline.ll      |   1 +
 llvm/test/CodeGen/AArch64/dotprodmatch.ll     | 684 ++++++++++++++++++
 6 files changed, 1185 insertions(+)
 create mode 100644 llvm/lib/Target/AArch64/AArch64DotProdMatcher.cpp
 create mode 100644 llvm/test/CodeGen/AArch64/dotprodmatch.ll

diff --git a/llvm/lib/Target/AArch64/AArch64.h b/llvm/lib/Target/AArch64/AArch64.h
index 901769c54b6ef59..afdc8e3698b2d99 100644
--- a/llvm/lib/Target/AArch64/AArch64.h
+++ b/llvm/lib/Target/AArch64/AArch64.h
@@ -71,6 +71,7 @@ FunctionPass *createAArch64PostSelectOptimize();
 FunctionPass *createAArch64StackTaggingPass(bool IsOptNone);
 FunctionPass *createAArch64StackTaggingPreRAPass();
 ModulePass *createAArch64GlobalsTaggingPass();
+FunctionPass *createAArch64DotProdMatcherPass();
 
 void initializeAArch64A53Fix835769Pass(PassRegistry&);
 void initializeAArch64A57FPLoadBalancingPass(PassRegistry&);
@@ -108,6 +109,7 @@ void initializeFalkorMarkStridedAccessesLegacyPass(PassRegistry&);
 void initializeLDTLSCleanupPass(PassRegistry&);
 void initializeSMEABIPass(PassRegistry &);
 void initializeSVEIntrinsicOptsPass(PassRegistry &);
+void initializeAArch64DotProdMatcherPass(PassRegistry &);
 } // end namespace llvm
 
 #endif
diff --git a/llvm/lib/Target/AArch64/AArch64DotProdMatcher.cpp b/llvm/lib/Target/AArch64/AArch64DotProdMatcher.cpp
new file mode 100644
index 000000000000000..44215efee75c33c
--- /dev/null
+++ b/llvm/lib/Target/AArch64/AArch64DotProdMatcher.cpp
@@ -0,0 +1,486 @@
+//===- AArch64DotProdMatcher - Matches instruction sequences to *DOT ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass recognizes and transforms IR to make use of two relatively simple
+// cases that can be implemented by the SDOT and UDOT instructions on AArch64
+// in order to increase vector unit bandwidth.
+//
+//===----------------------------------------------------------------------===//
+
+#include "AArch64.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IntrinsicsAArch64.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/InstructionCost.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include "Utils/AArch64BaseInfo.h"
+#include <deque>
+#include <optional>
+#include <tuple>
+#include <utility>
+
+using namespace llvm;
+using namespace llvm::PatternMatch;
+
+// Debug output category; also used as the pass's registered command-line name
+// in INITIALIZE_PASS_BEGIN/END below.
+#define DEBUG_TYPE "aarch64-dot-product-matcher"
+
+// Maximum number of add levels findDOTAccumulatorsInLoop descends through
+// (e.g. intermediate adds combining the partial sums of an interleaved loop)
+// before it gives up on a reduction.
+#define DOT_ACCUMULATOR_DEPTH (4)
+
+// Pass statistics (see llvm/ADT/Statistic.h).
+STATISTIC(NumDOTInstrs, "Number of DOT Instructions generated.");
+STATISTIC(NumSimpleDOTReplacements, "Num of simple dot patterns replaced.");
+STATISTIC(NumLoopDOTReplacements, "Num of loop dot patterns replaced.");
+
+namespace {
+
+// Describes one in-loop accumulation of a product of extended values that can
+// be rewritten to an SVE [s|u]dot intrinsic. Instances are recorded by
+// checkLoopAcc and consumed by AArch64DotProdMatcher::tryLoopDotReplacement.
+// Kept in an anonymous namespace so this file-local type cannot collide (ODR)
+// with an identically named type in another translation unit.
+struct LoopAccumulate {
+  Value *RVal;           // In-loop add that updates the accumulator.
+  PHINode *Phi;          // Accumulating PHI node fed by RVal.
+  Value *IterVals;       // Non-PHI operand of RVal (a select or the multiply).
+  Value *Predicate;      // Select condition, or nullptr if there was no select.
+  Value *Mul;            // Multiply of the two extended values.
+  Value *ValA;           // First multiply input before extension.
+  Value *ValB;           // Second multiply input before extension.
+  VectorType *VTy;       // Vector type of ValA.
+  Type *AccTy;           // Vector type of the dot product accumulator.
+  BasicBlock *LoopBlock; // Block supplying the loop-carried PHI value.
+  BasicBlock *PHBlock;   // Block supplying the initial (zero) PHI value.
+  bool IsSExt;           // True for sign extension (sdot), false for udot.
+
+  LoopAccumulate(Value *RVal, PHINode *Phi, Value *IterVals, Value *Predicate,
+                 Value *Mul, Value *ValA, Value *ValB, VectorType *VTy,
+                 Type *AccTy, BasicBlock *LoopBlock, BasicBlock *PHBlock,
+                 bool IsSExt)
+    : RVal(RVal), Phi(Phi), IterVals(IterVals), Predicate(Predicate),
+    Mul(Mul), ValA(ValA), ValB(ValB), VTy(VTy), AccTy(AccTy), LoopBlock(LoopBlock),
+    PHBlock(PHBlock), IsSExt(IsSExt) {}
+};
+
+} // end anonymous namespace
+
+// Returns true if I is a call to the llvm.vector.reduce.add intrinsic whose
+// vector operand has a scalable vector type.
+static bool isScalableIntegerSumReduction(Instruction &I) {
+  const auto *Reduce = dyn_cast<IntrinsicInst>(&I);
+  if (!Reduce || Reduce->getIntrinsicID() != Intrinsic::vector_reduce_add)
+    return false;
+  return isa<ScalableVectorType>(Reduce->getOperand(0)->getType());
+}
+
+// Chooses the dot product accumulator vector type for inputs with element
+// type EltTy that were extended to ExtEltTy:
+//   i8  inputs -> i32 lanes, if ExtEltTy is at most 32 bits wide;
+//   i16 inputs -> i64 lanes, if ExtEltTy is at most 64 bits wide.
+// The returned vector has EC elements. Any other combination yields nullptr.
+static Type *getAccumulatorType(Type *EltTy, Type *ExtEltTy, ElementCount EC) {
+  LLVMContext &Ctx = EltTy->getContext();
+  Type *LaneTy = nullptr;
+
+  if (EltTy->isIntegerTy(8)) {
+    if (ExtEltTy->getScalarSizeInBits() <= 32)
+      LaneTy = Type::getInt32Ty(Ctx);
+  } else if (EltTy->isIntegerTy(16)) {
+    if (ExtEltTy->getScalarSizeInBits() <= 64)
+      LaneTy = Type::getInt64Ty(Ctx);
+  }
+
+  return LaneTy ? VectorType::get(LaneTy, EC) : nullptr;
+}
+
+// Validates the accumulating PHI node and identifies its two incoming blocks.
+// Only a PHI with exactly two incoming values is accepted, where one edge
+// comes from the PHI's own block (a single-block loop) and the other edge
+// carries a null constant. Returns {LoopBlock, PHBlock} on success, or
+// std::nullopt if any check fails.
+static std::optional<std::pair<BasicBlock*, BasicBlock*>>
+getPHIIncomingBlocks(PHINode *Phi) {
+  if (Phi->getNumIncomingValues() != 2)
+    return std::nullopt;
+
+  BasicBlock *PhiBlock = Phi->getParent();
+  BasicBlock *Incoming0 = Phi->getIncomingBlock(0);
+  BasicBlock *Incoming1 = Phi->getIncomingBlock(1);
+
+  // One edge must loop back from the PHI's own block; anything else is either
+  // not a loop or a multi-block loop, neither of which is handled yet.
+  BasicBlock *LoopBlock = nullptr;
+  BasicBlock *PHBlock = nullptr;
+  if (Incoming0 == PhiBlock) {
+    LoopBlock = Incoming0;
+    PHBlock = Incoming1;
+  } else if (Incoming1 == PhiBlock) {
+    LoopBlock = Incoming1;
+    PHBlock = Incoming0;
+  } else {
+    return std::nullopt;
+  }
+
+  // The starting value from the preheader must be a null constant; other
+  // starting values may be supported in future.
+  auto *Start = dyn_cast<Constant>(Phi->getIncomingValueForBlock(PHBlock));
+  if (!Start || !Start->isNullValue())
+    return std::nullopt;
+
+  return std::make_pair(LoopBlock, PHBlock);
+}
+
+// Checks whether RVal, an add of the accumulating PHI node OldPHI and
+// IterVals, is a loop accumulation of the product of two extended values that
+// can be replaced by an [s|u]dot intrinsic. IterVals may be the multiply
+// itself or a select between the multiply and zero, in which case the select
+// condition is recorded as the predicate. On success a LoopAccumulate
+// describing the pattern is appended to Accumulators and true is returned;
+// otherwise Accumulators is left untouched and false is returned.
+static bool checkLoopAcc(Value *RVal, PHINode *OldPHI, Value *IterVals,
+                         SmallVectorImpl<LoopAccumulate> &Accumulators) {
+  bool IsSExt = false;
+
+  // We only expect the add in the loop to be used by the reduction and by
+  // the PHI node.
+  if (!RVal->hasNUses(2) || !is_contained(OldPHI->incoming_values(), RVal)) {
+    LLVM_DEBUG(dbgs() << "Loop sum operation has more than two uses or isn't "
+                         "used by the accumulating PHI node.\n");
+    return false;
+  }
+
+  // Look through selects with zeroinitializer. Record the predicate so
+  // we can insert selects for the base values later.
+  Value *Predicate = nullptr, *Mul = nullptr;
+  if (!match(IterVals, m_Select(m_Value(Predicate), m_Value(Mul), m_Zero()))) {
+    // A select with a non-zero false operand can bind Predicate before the
+    // match fails, so clear it explicitly.
+    Predicate = nullptr;
+    Mul = IterVals;
+  }
+
+  Value *ValA = nullptr, *ValB = nullptr;
+  // Match the core pattern of element-wise multiplication of extended values.
+  if (match(Mul, m_OneUse(m_Mul(m_SExt(m_OneUse(m_Value(ValA))),
+                                m_SExt(m_OneUse(m_Value(ValB)))))))
+    IsSExt = true;
+  else if (!match(Mul, m_OneUse(m_Mul(m_ZExt(m_OneUse(m_Value(ValA))),
+                                      m_ZExt(m_OneUse(m_Value(ValB))))))) {
+    LLVM_DEBUG(dbgs() << "Couldn't match inner loop multiply: "
+                      << *Mul << "\n");
+    return false;
+  }
+
+  // The same extended value could be used for both operands of the multiply,
+  // so we just need to check that they have a single user. m_Mul also matches
+  // constant expressions, so don't assume the multiply is an instruction.
+  auto *MulInst = dyn_cast<Instruction>(Mul);
+  if (!MulInst || !MulInst->getOperand(0)->hasOneUser() ||
+      !MulInst->getOperand(1)->hasOneUser())
+    return false;
+
+  // Both inputs are fed directly into the dot intrinsic (and, when predicated,
+  // into selects against a zero vector of ValA's type), so they must share a
+  // type. The extends may start from different widths, e.g. i8 and i16.
+  if (ValA->getType() != ValB->getType()) {
+    LLVM_DEBUG(dbgs() << "Multiply inputs have different types.\n");
+    return false;
+  }
+
+  // Check that the vector type is one packed vector's worth of data.
+  // TODO: Do we want to allow multiples?
+  VectorType *ValTy = cast<VectorType>(ValA->getType());
+  if (ValTy->getPrimitiveSizeInBits().getKnownMinValue() !=
+      AArch64::SVEBitsPerBlock) {
+    LLVM_DEBUG(dbgs() << "Vector base size is not a packed representation.\n");
+    return false;
+  }
+
+  // Find the accumulator element type after extension and check that it isn't
+  // too large; if it is, we might lose data by converting to dot instructions.
+  // The element count needs to be 1/4th that of the input data, since the
+  // dot product instructions take four smaller elements and multiply/accumulate
+  // them into one larger element.
+  Type *AccTy = getAccumulatorType(ValTy->getElementType(),
+      Mul->getType()->getScalarType(),
+      ValTy->getElementCount().divideCoefficientBy(4));
+
+  if (!AccTy) {
+    LLVM_DEBUG(dbgs() << "Accumulator element type too wide.\n");
+    return false;
+  }
+
+  // Validate the phi node and retrieve the incoming basic blocks for the
+  // accumulating loop itself and the preheader.
+  auto PhiBlocks = getPHIIncomingBlocks(OldPHI);
+
+  if (!PhiBlocks) {
+    LLVM_DEBUG(dbgs() << "Unable to match PHI node\n");
+    return false;
+  }
+
+  // Everything looks in order, so add it to the list of accumulators to
+  // transform.
+  Accumulators.emplace_back(RVal, OldPHI, IterVals, Predicate, Mul, ValA,
+                            ValB, ValTy, AccTy, PhiBlocks->first,
+                            PhiBlocks->second, IsSExt);
+  return true;
+}
+
+// Walks the add tree that feeds a sum reduction, at most Depth levels deep,
+// and records each loop accumulator found via checkLoopAcc. An add whose
+// operands are exactly one PHI node and one non-PHI value, both with a single
+// use, is checked as a loop accumulator. Any other add is assumed to combine
+// partial sums (e.g. from an interleaved loop), and both of its operands are
+// searched in turn. Returns true only if every leaf of the tree was accepted.
+static bool findDOTAccumulatorsInLoop(Value *RVal,
+                                SmallVectorImpl<LoopAccumulate> &Accumulators,
+                                unsigned Depth = DOT_ACCUMULATOR_DEPTH) {
+  // Stop at the recursion limit, or at anything that isn't an add.
+  Value *LHS = nullptr, *RHS = nullptr;
+  if (Depth == 0 || !match(RVal, m_Add(m_Value(LHS), m_Value(RHS))))
+    return false;
+
+  const bool BothSingleUse = LHS->hasOneUse() && RHS->hasOneUse();
+  auto *LHSPhi = dyn_cast<PHINode>(LHS);
+  auto *RHSPhi = dyn_cast<PHINode>(RHS);
+
+  // Exactly one side is a PHI: this is the accumulating add inside the loop.
+  if (BothSingleUse && LHSPhi && !RHSPhi)
+    return checkLoopAcc(RVal, LHSPhi, RHS, Accumulators);
+  if (BothSingleUse && RHSPhi && !LHSPhi)
+    return checkLoopAcc(RVal, RHSPhi, LHS, Accumulators);
+
+  // Otherwise treat this as an intermediate multi-register reduction step.
+  return findDOTAccumulatorsInLoop(LHS, Accumulators, Depth - 1) &&
+         findDOTAccumulatorsInLoop(RHS, Accumulators, Depth - 1);
+}
+
+namespace {
+
+// Legacy function pass that rewrites scalable integer sum reductions of
+// extended multiplies into SVE [s|u]dot intrinsics, either directly
+// (trySimpleDotReplacement) or via a loop accumulator (tryLoopDotReplacement).
+class AArch64DotProdMatcher : public FunctionPass {
+public:
+  static char ID;
+  AArch64DotProdMatcher() : FunctionPass(ID) {
+    initializeAArch64DotProdMatcherPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override {
+    // Honour optnone and -opt-bisect-limit like other optimization passes.
+    if (skipFunction(F))
+      return false;
+
+    TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+
+    // Collect candidates up front: the rewrites below insert and erase
+    // instructions, which must not happen while iterating over the blocks.
+    SmallVector<Instruction *, 4> Reductions;
+    for (BasicBlock &Block : F)
+      // TODO: Support non-scalable dot instructions too.
+      for (Instruction &I : make_filter_range(Block,
+                                              isScalableIntegerSumReduction))
+        Reductions.push_back(&I);
+
+    bool Changed = false;
+    for (auto *Rdx : Reductions)
+      Changed |= trySimpleDotReplacement(*Rdx) || tryLoopDotReplacement(*Rdx);
+
+    return Changed;
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<TargetTransformInfoWrapperPass>();
+    AU.setPreservesCFG();
+  }
+
+private:
+  // Target information for the function being processed. runOnFunction sets
+  // it before any replacement is attempted.
+  TargetTransformInfo *TTI = nullptr;
+
+  bool trySimpleDotReplacement(Instruction &I);
+  bool tryLoopDotReplacement(Instruction &I);
+};
+
+} // end anonymous namespace
+
+char AArch64DotProdMatcher::ID = 0;
+// Register the pass under the command-line name DEBUG_TYPE
+// ("aarch64-dot-product-matcher"), declaring its dependency on
+// TargetTransformInfoWrapperPass.
+INITIALIZE_PASS_BEGIN(AArch64DotProdMatcher, DEBUG_TYPE,
+                "AArch64 Dot Product Instruction Matcher", false, false)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
+INITIALIZE_PASS_END(AArch64DotProdMatcher, DEBUG_TYPE,
+                "AArch64 Dot Product Instruction Matcher", false, false)
+
+// Creates a new instance of the dot product matcher pass; the caller (the
+// pass manager) takes ownership.
+FunctionPass *llvm::createAArch64DotProdMatcherPass() {
+  return new AArch64DotProdMatcher();
+}
+
+// The following method looks for a simple pattern of two values being either
+// sign or zero extended, multiplied together, then summed. If the types
+// match the ones used by the [s|u]dot instructions (groups of 4x8 -> 32,
+// groups of 4x16 -> 64) then we can replace the extends and multiply with a
+// dot instruction and swap the reduce for one using fewer elements.
+//
+//      +-----------+   +-----------+
+//      |   ValA    |   |   ValB    |
+//      +-----+-----+   +-----+-----+
+//            |               |
+//            |               |
+//      +-----v-----+   +-----v-----+
+//      | [S|Z]Ext  |   | [S|Z]Ext  |
+//      +-----+-----+   +-----+-----+
+//            |               |
+//            +--+         +--+
+//               |         |
+//              +v---------v+
+//              |    Mul    |
+//              +-----+-----+
+//                    |
+//                    |
+//              +-----v-----+
+//              | Reduce(+) |
+//              +-----------+
+bool AArch64DotProdMatcher::trySimpleDotReplacement(Instruction &I) {
+  LLVM_DEBUG(dbgs() << "Looking for simple dot reduction: " << I << "\n");
+  Value *RVal = I.getOperand(0);
+  Value *ValA = nullptr, *ValB = nullptr;
+  bool IsSExt = false;
+
+  // Match reduce(mul(ext(A), ext(B))) where both extends are of the same kind.
+  if (match(RVal, m_Mul(m_SExt(m_Value(ValA)), m_SExt(m_Value(ValB)))))
+    IsSExt = true;
+  else if (!match(RVal, m_Mul(m_ZExt(m_Value(ValA)), m_ZExt(m_Value(ValB))))) {
+    LLVM_DEBUG(dbgs() << "Unable to match simple dot pattern\n");
+    return false;
+  }
+
+  // m_Mul also matches constant expressions, but the dot product is inserted
+  // in front of the multiply, so it has to be a real instruction.
+  auto *MulInst = dyn_cast<Instruction>(RVal);
+  if (!MulInst) {
+    LLVM_DEBUG(dbgs() << "Simple dot multiply is not an instruction\n");
+    return false;
+  }
+
+  // Only the 4x8 -> 32 and 4x16 -> 64 forms map onto [s|u]dot, and both
+  // inputs must have the same type.
+  VectorType *ATy = cast<VectorType>(ValA->getType());
+  VectorType *BTy = cast<VectorType>(ValB->getType());
+  VectorType *MTy = cast<VectorType>(RVal->getType());
+  if (ATy != BTy || !((ATy->getScalarType()->isIntegerTy(8) &&
+                       MTy->getScalarType()->isIntegerTy(32)) ||
+                      (ATy->getScalarType()->isIntegerTy(16) &&
+                       MTy->getScalarType()->isIntegerTy(64)))) {
+    LLVM_DEBUG(dbgs() << "Unable to match types for simple dot pattern\n");
+    return false;
+  }
+
+  // The inputs must be exactly one scalable vector register wide.
+  if (TTI->getRegisterBitWidth(TargetTransformInfo::RGK_ScalableVector) !=
+      ATy->getPrimitiveSizeInBits())
+    return false;
+
+  // All conditions met, proceed with replacement.
+  IRBuilder<> Builder(MulInst);
+
+  // The accumulator has a quarter of the multiply's elements, since each dot
+  // product lane sums four products.
+  Type *AccTy = VectorType::get(MTy->getScalarType(),
+                                MTy->getElementCount().divideCoefficientBy(4));
+  Value *Zeroes = ConstantAggregateZero::get(AccTy);
+
+  Intrinsic::ID IntID = IsSExt ? Intrinsic::aarch64_sve_sdot
+                               : Intrinsic::aarch64_sve_udot;
+  Value *DotProd = Builder.CreateIntrinsic(IntID, {AccTy},
+                                           {Zeroes, ValA, ValB});
+  Builder.SetInsertPoint(&I);
+  Value *Reduce = Builder.CreateAddReduce(DotProd);
+  I.replaceAllUsesWith(Reduce);
+
+  // The original reduction is now dead. Erase it, together with any part of
+  // the extend/multiply chain that has no other users.
+  RecursivelyDeleteTriviallyDeadInstructions(&I);
+  NumDOTInstrs++;
+  NumSimpleDOTReplacements++;
+  return true;
+}
+
+// This method looks for the following pattern: It starts from a sum
+// reduction, but expects to find a vector add operation inside a loop with one
+// of the operands being a PHI. The other operand can either be a select
+// between zeroes and a multiply, or just the multiply directly. The rest of
+// the pattern is the same as the simpler case -- multiply of extends of some
+// values.
+//
+// Replacing this is a little tricky, since we need to replace the PHI node
+// and accumulator as well, and potentially add in new selects earlier, but if
+// everything checks out then the extend -> multiply -> inner loop add operation
+// is replaced by the [s|u]dot instruction.
+//
+//                                     +-----------+
+//                                     |   Zero    |
+//                                     +-+---------+
+//  +-------+      +---------------------+   |
+//  |       |      |                         |
+//  |    +--v------v-+                       |
+//  |    |  OldPHI   |                       |
+//  |    +--+--------+                       |
+//  |       |                                |
+//  |       |   +-----------+   +-----------+|
+//  |       |   |   ValA    |   |   ValB    ||
+//  |       |   +-----+-----+   +-----+-----+|
+//  |       |         |               |      |
+//  |       |         |               |      |
+//  |       |   +-----v-----+   +-----v-----+|
+//  |       |   | [S|Z]Ext  |   | [S|Z]Ext  ||
+//  |       |   +-----+-----+   +-----+-----+|
+//  |       |         |               |      |
+//  |       |         +--+         +--+      |
+//  |       |            |         |         |
+//  |       |           +v---------v+        |
+//  |       |           |    Mul    |        |
+//  |       |           +-+---------+        |
+//  |       |             |       +----------+
+//  |       |             |       |
+//  |       |           +-v-------v-+
+//  |       |           |  Select   |
+//  |       |           +--+--------+
+//  |       |              |
+//  |       |              |
+//  |       |              |
+//  |    +--v--------------v---+
+//  |    |         Add         |
+//  |    +--+-------+----------+
+//  |       |       |
+//  +-------+       |
+//                  |
+//            +-----v-----+
+//            | Reduce(+) |
+//            +-----------+
+bool AArch64DotProdMatcher::tryLoopDotReplacement(Instruction &I) {
+  LLVM_DEBUG(dbgs() << "Looking for Loop DOT Reduction: " << I << "\n");
+  Value *RVal = I.getOperand(0);
+  SmallVector<LoopAccumulate, 4> Accumulators;
+  // Dot product results still to be combined, processed first-in first-out.
+  std::deque<Value *> RdxVals;
+  IRBuilder<> Builder(&I);
+
+  // If the loop was interleaved, we may have some intermediate add
+  // instructions first before we get to the accumulators inside the
+  // loop. Gather those first then process them.
+  if (!findDOTAccumulatorsInLoop(RVal, Accumulators)) {
+    LLVM_DEBUG(dbgs() << "Couldn't find DOT accumulators in the loop\n");
+    return false;
+  }
+
+  // All conditions met, proceed with replacement.
+  for (auto &Acc : Accumulators) {
+    Builder.SetInsertPoint(Acc.Phi);
+
+    // Plant a new PHI of the accumulator type that starts at zero on entry
+    // from the preheader.
+    PHINode *DotAcc = Builder.CreatePHI(Acc.AccTy, 2, "dot.accumulate");
+    Value *AccZeroes = ConstantAggregateZero::get(Acc.AccTy);
+    DotAcc->addIncoming(AccZeroes, Acc.PHBlock);
+
+    // Move to the dot insertion point, just before the old in-loop add.
+    Builder.SetInsertPoint(cast<Instruction>(Acc.RVal));
+
+    // If the multiply was guarded by a select against zero, apply the same
+    // predicate to both inputs so inactive lanes contribute nothing.
+    // Hopefully we can fold away some extra selects (e.g. if the data
+    // originally came from masked loads with the same predicate).
+    if (Acc.Predicate) {
+      Value *InputZeroes = ConstantAggregateZero::get(Acc.VTy);
+      Acc.ValA = Builder.CreateSelect(Acc.Predicate, Acc.ValA, InputZeroes);
+      Acc.ValB = Builder.CreateSelect(Acc.Predicate, Acc.ValB, InputZeroes);
+    }
+
+    // Now plant the dot instruction and feed it back into the new PHI.
+    Intrinsic::ID IntID = Acc.IsSExt ? Intrinsic::aarch64_sve_sdot
+                                     : Intrinsic::aarch64_sve_udot;
+    Value *DotProd = Builder.CreateIntrinsic(IntID, {Acc.AccTy},
+                                             {DotAcc, Acc.ValA, Acc.ValB});
+    DotAcc->addIncoming(DotProd, Acc.LoopBlock);
+
+    RdxVals.push_back(DotProd);
+
+    NumDOTInstrs++;
+  }
+
+  assert(!RdxVals.empty() &&
+         "We found accumulators but generated no RdxVals");
+
+  // Combine the per-accumulator results pairwise in front of the original
+  // root of the add tree.
+  Builder.SetInsertPoint(cast<Instruction>(RVal));
+
+  while (RdxVals.size() > 1) {
+    RdxVals.push_back(Builder.CreateAdd(RdxVals[0], RdxVals[1]));
+    // Drop the two RdxVals we just reduced. Sadly, there's no SmallDeque
+    // with a pop_front_val() convenience method yet.
+    RdxVals.pop_front();
+    RdxVals.pop_front();
+  }
+
+  // Plant the new reduction and truncate it to the original result type
+  // (CreateTrunc is a no-op when the types already match).
+  Builder.SetInsertPoint(&I);
+  Value *Reduce = Builder.CreateAddReduce(RdxVals.front());
+  Value *Trunc = Builder.CreateTrunc(Reduce, I.getType(), "dot.trunc");
+  I.replaceAllUsesWith(Trunc);
+
+  // Delete the original reduction, since it's no longer required.
+  RecursivelyDeleteTriviallyDeadInstructions(&I);
+  NumLoopDOTReplacements++;
+  return true;
+}
+
diff --git a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
index 3d818c76bd4b7d7..4a76d2f705a5a13 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
@@ -165,6 +165,11 @@ static cl::opt<bool>
                            cl::desc("Enable SVE intrinsic opts"),
                            cl::init(true));
 
+// Controls whether the AArch64 dot product matcher pass is added to the IR
+// pipeline in addIRPasses (only when not compiling at -O0). On by default.
+static cl::opt<bool>
+EnableAArch64DotProdMatch("aarch64-enable-dotprodmatch", cl::Hidden,
+                          cl::desc("Enable matching dot product instructions"),
+                          cl::init(true));
+
 static cl::opt<bool> EnableFalkorHWPFFix("aarch64-enable-falkor-hwpf-fix",
                                          cl::init(true), cl::Hidden);
 
@@ -246,6 +251,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeAArch64Target() {
   initializeAArch64LowerHomogeneousPrologEpilogPass(*PR);
   initializeAArch64DAGToDAGISelPass(*PR);
   initializeAArch64GlobalsTaggingPass(*PR);
+  initializeAArch64DotProdMatcherPass(*PR);
 }
 
 //===----------------------------------------------------------------------===//
@@ -553,6 +559,11 @@ void AArch64PassConfig::addIRPasses() {
   // ourselves.
   addPass(createAtomicExpandPass());
 
+  // Make use of SVE intrinsics in place of common vector operations that span
+  // multiple basic blocks.
+  if (TM->getOptLevel() != CodeGenOptLevel::None && EnableAArch64DotProdMatch)
+    addPass(createAArch64DotProdMatcherPass());
+
   // Expand any SVE vector library calls that we can't code generate directly.
   if (EnableSVEIntrinsicOpts &&
       TM->getOptLevel() == CodeGenOptLevel::Aggressive)
diff --git a/llvm/lib/Target/AArch64/CMakeLists.txt b/llvm/lib/Target/AArch64/CMakeLists.txt
index d97342b0829d826..b89ce94b9312277 100644
--- a/llvm/lib/Target/AArch64/CMakeLists.txt
+++ b/llvm/lib/Target/AArch64/CMakeLists.txt
@@ -50,6 +50,7 @@ add_llvm_target(AArch64CodeGen
   AArch64CondBrTuning.cpp
   AArch64ConditionalCompares.cpp
   AArch64DeadRegisterDefinitionsPass.cpp
+  AArch64DotProdMatcher.cpp
   AArch64ExpandImm.cpp
   AArch64ExpandPseudoInsts.cpp
   AArch64FalkorHWPFFix.cpp
diff --git a/llvm/test/CodeGen/AArch64/O3-pipeline.ll b/llvm/test/CodeGen/AArch64/O3-pipeline.ll
index f5c1c3c291cb585..7d196b8579d202b 100644
--- a/llvm/test/CodeGen/AArch64/O3-pipeline.ll
+++ b/llvm/test/CodeGen/AArch64/O3-pipeline.ll
@@ -22,6 +22,7 @@
 ; CHECK-NEXT:       Expand large div/rem
 ; CHECK-NEXT:       Expand large fp convert
 ; CHECK-NEXT:       Expand Atomic instructions
+; CHECK-NEXT:       AArch64 Dot Product Instruction Matcher
 ; CHECK-NEXT:     SVE intrinsics optimizations
 ; CHECK-NEXT:       FunctionPass Manager
 ; CHECK-NEXT:         Dominator Tree Construction
diff --git a/llvm/test/CodeGen/AArch64/dotprodmatch.ll b/llvm/test/CodeGen/AArch64/dotprodmatch.ll
new file mode 100644
index 000000000000000..a75048351b81030
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/dotprodmatch.ll
@@ -0,0 +1,684 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2
+; RUN: opt -S -aarch64-dot-product-matcher -instcombine < %s | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+define i16 @sve_sdot_loop_i16_to_i32(ptr readonly %a, ptr readonly %b, i32 %N) #0 {
+; CHECK-LABEL: define i16 @sve_sdot_loop_i16_to_i32
+; CHECK-SAME: (ptr readonly [[A:%.*]], ptr readonly [[B:%.*]], i32 [[N:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP11:%.*]] = icmp sgt i32 [[N]], 0
+; CHECK-NEXT:    br i1 [[CMP11]], label [[MIN_ITERS_CHECKED:%.*]], label [[FOR_COND_CLEANUP:%.*]]
+; CHECK:       min.iters.checked:
+; CHECK-NEXT:    [[WIDE_TRIP_COUNT:%.*]] = zext i32 [[N]] to i64
+; CHECK-NEXT:    [[PREDICATE_ENTRY:%.*]] = call <vscale x 8 x i1> @llvm.aarch64.sve.whilelo.nxv8i1.i64(i64 0, i64 [[WIDE_TRIP_COUNT]])
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[MIN_ITERS_CHECKED]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[PREDICATE:%.*]] = phi <vscale x 8 x i1> [ [[PREDICATE_ENTRY]], [[MIN_ITERS_CHECKED]] ], [ [[PREDICATE_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[DOT_ACCUMULATE:%.*]] = phi <vscale x 2 x i64> [ zeroinitializer, [[MIN_ITERS_CHECKED]] ], [ [[TMP2:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i16, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr [[TMP0]], i32 2, <vscale x 8 x i1> [[PREDICATE]], <vscale x 8 x i16> zeroinitializer)
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i16, ptr [[B]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD19:%.*]] = call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr [[TMP1]], i32 2, <vscale x 8 x i1> [[PREDICATE]], <vscale x 8 x i16> zeroinitializer)
+; CHECK-NEXT:    [[TMP2]] = call <vscale x 2 x i64> @llvm.aarch64.sve.sdot.nxv2i64(<vscale x 2 x i64> [[DOT_ACCUMULATE]], <vscale x 8 x i16> [[WIDE_MASKED_LOAD19]], <vscale x 8 x i16> [[WIDE_MASKED_LOAD]])
+; CHECK-NEXT:    [[VS:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[VS_SCALED:%.*]] = shl i64 [[VS]], 3
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[VS_SCALED]]
+; CHECK-NEXT:    [[PREDICATE_NEXT]] = call <vscale x 8 x i1> @llvm.aarch64.sve.whilelo.nxv8i1.i64(i64 [[INDEX_NEXT]], i64 [[WIDE_TRIP_COUNT]])
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <vscale x 8 x i1> [[PREDICATE_NEXT]], i64 0
+; CHECK-NEXT:    br i1 [[TMP3]], label [[VECTOR_BODY]], label [[MIDDLE_BLOCK:%.*]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vector.reduce.add.nxv2i64(<vscale x 2 x i64> [[TMP2]])
+; CHECK-NEXT:    [[PHITMP201:%.*]] = lshr i64 [[TMP4]], 16
+; CHECK-NEXT:    [[PHITMP:%.*]] = trunc i64 [[PHITMP201]] to i16
+; CHECK-NEXT:    br label [[FOR_COND_CLEANUP]]
+; CHECK:       for.cond.cleanup:
+; CHECK-NEXT:    [[ACC_0_LCSSA:%.*]] = phi i16 [ 0, [[ENTRY:%.*]] ], [ [[PHITMP]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT:    ret i16 [[ACC_0_LCSSA]]
+;
+entry:  ; Computes sum(sext(a[i]) * sext(b[i])) in i32 over a whilelo-predicated loop and returns bits [31:16] of the sum as i16.
+  %cmp11 = icmp sgt i32 %N, 0  ; N <= 0 skips the loop and returns 0
+  br i1 %cmp11, label %min.iters.checked, label %for.cond.cleanup
+
+min.iters.checked:                                ; preds = %entry
+  %wide.trip.count = zext i32 %N to i64
+  %wide.end.idx.splatinsert = insertelement <vscale x 8 x i64> undef, i64 %wide.trip.count, i32 0  ; unused, like the splat below; absent from the CHECK lines
+  %wide.end.idx.splat = shufflevector <vscale x 8 x i64> %wide.end.idx.splatinsert, <vscale x 8 x i64> undef, <vscale x 8 x i32> zeroinitializer
+  %predicate.entry = call <vscale x 8 x i1> @llvm.aarch64.sve.whilelo.nxv8i1.i64(i64 0, i64 %wide.trip.count)
+  br label %vector.body
+
+vector.body:                                      ; preds = %vector.body, %min.iters.checked
+  %index = phi i64 [ 0, %min.iters.checked ], [ %index.next, %vector.body ]
+  %predicate = phi <vscale x 8 x i1> [ %predicate.entry, %min.iters.checked ], [ %predicate.next, %vector.body ]
+  %vec.phi = phi <vscale x 8 x i32> [ zeroinitializer, %min.iters.checked ], [ %6, %vector.body ]  ; i32 accumulator; CHECK expects an nxv2i64 DOT_ACCUMULATE phi instead
+  %0 = getelementptr inbounds i16, ptr %a, i64 %index
+  %wide.masked.load = call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr %0, i32 2, <vscale x 8 x i1> %predicate, <vscale x 8 x i16> undef)  ; undef passthru; CHECK expects zeroinitializer after the rewrite
+  %1 = sext <vscale x 8 x i16> %wide.masked.load to <vscale x 8 x i32>
+  %2 = getelementptr inbounds i16, ptr %b, i64 %index
+  %wide.masked.load19 = call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr %2, i32 2, <vscale x 8 x i1> %predicate, <vscale x 8 x i16> undef)
+  %3 = sext <vscale x 8 x i16> %wide.masked.load19 to <vscale x 8 x i32>
+  %4 = mul nsw <vscale x 8 x i32> %3, %1
+  %5 = select <vscale x 8 x i1> %predicate, <vscale x 8 x i32> %4, <vscale x 8 x i32> zeroinitializer  ; zero the products of inactive lanes
+  %6 = add nsw <vscale x 8 x i32> %vec.phi, %5  ; end of the sext -> mul -> accumulate chain; CHECK expects one sve.sdot call in its place
+  %vs = call i64 @llvm.vscale.i64()
+  %vs.scaled = mul i64 %vs, 8  ; 8 * vscale i16 elements per iteration
+  %index.next = add nuw i64 %index, %vs.scaled
+  %.splatinsert = insertelement <vscale x 8 x i64> undef, i64 %index.next, i32 0  ; unused, like %.splat below
+  %.splat = shufflevector <vscale x 8 x i64> %.splatinsert, <vscale x 8 x i64> undef, <vscale x 8 x i32> zeroinitializer
+  %predicate.next = call <vscale x 8 x i1> @llvm.aarch64.sve.whilelo.nxv8i1.i64(i64 %index.next, i64 %wide.trip.count)
+  %7 = extractelement <vscale x 8 x i1> %predicate.next, i64 0  ; loop while lane 0 of the next predicate is active
+  br i1 %7, label %vector.body, label %middle.block
+
+middle.block:                                     ; preds = %vector.body
+  %8 = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> %6)  ; horizontal sum of the i32 accumulator
+  %phitmp20 = lshr i32 %8, 16  ; bits [31:16] of the i32 sum become the i16 result
+  %phitmp = trunc i32 %phitmp20 to i16
+  br label %for.cond.cleanup
+
+for.cond.cleanup:                                 ; preds = %middle.block, %entry
+  %acc.0.lcssa = phi i16 [ 0, %entry ], [ %phitmp, %middle.block ]
+  ret i16 %acc.0.lcssa
+}
+
+define dso_local i16 @sve_sdot_loop_i16_to_i32_interleavedx2_scalartail(ptr readonly %a, ptr readonly %b, i32 %N) #0 {
+; CHECK-LABEL: define dso_local i16 @sve_sdot_loop_i16_to_i32_interleavedx2_scalartail
+; CHECK-SAME: (ptr readonly [[A:%.*]], ptr readonly [[B:%.*]], i32 [[N:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP9:%.*]] = icmp sgt i32 [[N]], 0
+; CHECK-NEXT:    br i1 [[CMP9]], label [[FOR_BODY_PREHEADER:%.*]], label [[FOR_COND_CLEANUP:%.*]]
+; CHECK:       for.body.preheader:
+; CHECK-NEXT:    [[WIDE_TRIP_COUNT:%.*]] = zext i32 [[N]] to i64
+; CHECK-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP1:%.*]] = shl nuw nsw i64 [[TMP0]], 4
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ugt i64 [[TMP1]], [[WIDE_TRIP_COUNT]]
+; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[FOR_BODY_PREHEADER17:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK:       vector.ph:
+; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[WIDE_TRIP_COUNT]], [[TMP1]]
+; CHECK-NEXT:    [[N_VEC:%.*]] = sub nuw nsw i64 [[WIDE_TRIP_COUNT]], [[N_MOD_VF]]
+; CHECK-NEXT:    [[TMP2:%.*]] = tail call i32 @llvm.vscale.i32()
+; CHECK-NEXT:    [[TMP3:%.*]] = shl nuw nsw i32 [[TMP2]], 3
+; CHECK-NEXT:    [[TMP4:%.*]] = zext i32 [[TMP3]] to i64
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[DOT_ACCUMULATE1:%.*]] = phi <vscale x 2 x i64> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP9:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[DOT_ACCUMULATE:%.*]] = phi <vscale x 2 x i64> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP10:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds i16, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 8 x i16>, ptr [[TMP5]], align 2
+; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr inbounds i16, ptr [[TMP5]], i64 [[TMP4]]
+; CHECK-NEXT:    [[WIDE_LOAD14:%.*]] = load <vscale x 8 x i16>, ptr [[TMP6]], align 2
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds i16, ptr [[B]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_LOAD15:%.*]] = load <vscale x 8 x i16>, ptr [[TMP7]], align 2
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds i16, ptr [[TMP7]], i64 [[TMP4]]
+; CHECK-NEXT:    [[WIDE_LOAD16:%.*]] = load <vscale x 8 x i16>, ptr [[TMP8]], align 2
+; CHECK-NEXT:    [[TMP9]] = call <vscale x 2 x i64> @llvm.aarch64.sve.sdot.nxv2i64(<vscale x 2 x i64> [[DOT_ACCUMULATE1]], <vscale x 8 x i16> [[WIDE_LOAD15]], <vscale x 8 x i16> [[WIDE_LOAD]])
+; CHECK-NEXT:    [[TMP10]] = call <vscale x 2 x i64> @llvm.aarch64.sve.sdot.nxv2i64(<vscale x 2 x i64> [[DOT_ACCUMULATE]], <vscale x 8 x i16> [[WIDE_LOAD16]], <vscale x 8 x i16> [[WIDE_LOAD14]])
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP1]]
+; CHECK-NEXT:    [[TMP11:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP11]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[TMP12:%.*]] = add <vscale x 2 x i64> [[TMP10]], [[TMP9]]
+; CHECK-NEXT:    [[TMP13:%.*]] = call i64 @llvm.vector.reduce.add.nxv2i64(<vscale x 2 x i64> [[TMP12]])
+; CHECK-NEXT:    [[DOT_TRUNC:%.*]] = trunc i64 [[TMP13]] to i32
+; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[N_MOD_VF]], 0
+; CHECK-NEXT:    [[EXTRACT4:%.*]] = lshr i64 [[TMP13]], 16
+; CHECK-NEXT:    [[EXTRACT_T:%.*]] = trunc i64 [[EXTRACT4]] to i16
+; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[FOR_BODY_PREHEADER17]]
+; CHECK:       for.body.preheader17:
+; CHECK-NEXT:    [[INDVARS_IV_PH:%.*]] = phi i64 [ 0, [[FOR_BODY_PREHEADER]] ], [ [[N_VEC]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT:    [[ACC_010_PH:%.*]] = phi i32 [ 0, [[FOR_BODY_PREHEADER]] ], [ [[DOT_TRUNC]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
+; CHECK:       for.cond.cleanup.loopexit:
+; CHECK-NEXT:    [[ADD_LCSSA_OFF16:%.*]] = phi i16 [ [[EXTRACT_T]], [[MIDDLE_BLOCK]] ], [ [[EXTRACT_T3:%.*]], [[FOR_BODY]] ]
+; CHECK-NEXT:    br label [[FOR_COND_CLEANUP]]
+; CHECK:       for.cond.cleanup:
+; CHECK-NEXT:    [[ACC_0_LCSSA:%.*]] = phi i16 [ 0, [[ENTRY:%.*]] ], [ [[ADD_LCSSA_OFF16]], [[FOR_COND_CLEANUP_LOOPEXIT]] ]
+; CHECK-NEXT:    ret i16 [[ACC_0_LCSSA]]
+; CHECK:       for.body:
+; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ], [ [[INDVARS_IV_PH]], [[FOR_BODY_PREHEADER17]] ]
+; CHECK-NEXT:    [[ACC_010:%.*]] = phi i32 [ [[ADD:%.*]], [[FOR_BODY]] ], [ [[ACC_010_PH]], [[FOR_BODY_PREHEADER17]] ]
+; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds i16, ptr [[A]], i64 [[INDVARS_IV]]
+; CHECK-NEXT:    [[TMP14:%.*]] = load i16, ptr [[ARRAYIDX]], align 2
+; CHECK-NEXT:    [[CONV:%.*]] = sext i16 [[TMP14]] to i32
+; CHECK-NEXT:    [[ARRAYIDX2:%.*]] = getelementptr inbounds i16, ptr [[B]], i64 [[INDVARS_IV]]
+; CHECK-NEXT:    [[TMP15:%.*]] = load i16, ptr [[ARRAYIDX2]], align 2
+; CHECK-NEXT:    [[CONV3:%.*]] = sext i16 [[TMP15]] to i32
+; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i32 [[CONV3]], [[CONV]]
+; CHECK-NEXT:    [[ADD]] = add nsw i32 [[MUL]], [[ACC_010]]
+; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add nuw nsw i64 [[INDVARS_IV]], 1
+; CHECK-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i64 [[INDVARS_IV_NEXT]], [[WIDE_TRIP_COUNT]]
+; CHECK-NEXT:    [[EXTRACT2:%.*]] = lshr i32 [[ADD]], 16
+; CHECK-NEXT:    [[EXTRACT_T3]] = trunc i32 [[EXTRACT2]] to i16
+; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP_LOOPEXIT]], label [[FOR_BODY]]
+;
+entry:  ; Unpredicated vector loop interleaved by 2 (two i32 accumulators) plus a scalar remainder loop; returns bits [31:16] of the i32 dot product, or 0 when N <= 0.
+  %cmp9 = icmp sgt i32 %N, 0
+  br i1 %cmp9, label %for.body.preheader, label %for.cond.cleanup
+
+for.body.preheader:                               ; preds = %entry
+  %wide.trip.count = zext i32 %N to i64
+  %0 = tail call i64 @llvm.vscale.i64()
+  %1 = shl nuw nsw i64 %0, 4  ; elements per vector iteration: 16 * vscale (two nxv8i16 parts)
+  %min.iters.check = icmp ugt i64 %1, %wide.trip.count  ; too few elements: run only the scalar loop
+  br i1 %min.iters.check, label %for.body.preheader17, label %vector.ph
+
+vector.ph:                                        ; preds = %for.body.preheader
+  %n.mod.vf = urem i64 %wide.trip.count, %1
+  %n.vec = sub nuw nsw i64 %wide.trip.count, %n.mod.vf
+  %2 = tail call i32 @llvm.vscale.i32()
+  %3 = shl nuw nsw i32 %2, 3
+  %4 = zext i32 %3 to i64  ; element offset (8 * vscale) of the second interleaved part
+  br label %vector.body
+
+vector.body:                                      ; preds = %vector.body, %vector.ph
+  %index = phi i64 [ 0, %vector.ph ], [ %index.next, %vector.body ]
+  %vec.phi = phi <vscale x 8 x i32> [ zeroinitializer, %vector.ph ], [ %15, %vector.body ]  ; accumulator for the first part
+  %vec.phi13 = phi <vscale x 8 x i32> [ zeroinitializer, %vector.ph ], [ %16, %vector.body ]  ; accumulator for the second part; CHECK expects one nxv2i64 accumulator per part
+  %5 = getelementptr inbounds i16, ptr %a, i64 %index
+  %wide.load = load <vscale x 8 x i16>, ptr %5, align 2
+  %6 = getelementptr inbounds i16, ptr %5, i64 %4
+  %wide.load14 = load <vscale x 8 x i16>, ptr %6, align 2
+  %7 = sext <vscale x 8 x i16> %wide.load to <vscale x 8 x i32>
+  %8 = sext <vscale x 8 x i16> %wide.load14 to <vscale x 8 x i32>
+  %9 = getelementptr inbounds i16, ptr %b, i64 %index
+  %wide.load15 = load <vscale x 8 x i16>, ptr %9, align 2
+  %10 = getelementptr inbounds i16, ptr %9, i64 %4
+  %wide.load16 = load <vscale x 8 x i16>, ptr %10, align 2
+  %11 = sext <vscale x 8 x i16> %wide.load15 to <vscale x 8 x i32>
+  %12 = sext <vscale x 8 x i16> %wide.load16 to <vscale x 8 x i32>
+  %13 = mul nsw <vscale x 8 x i32> %11, %7
+  %14 = mul nsw <vscale x 8 x i32> %12, %8
+  %15 = add <vscale x 8 x i32> %13, %vec.phi  ; no nsw on these accumulates; CHECK still expects an sve.sdot per part
+  %16 = add <vscale x 8 x i32> %14, %vec.phi13
+  %index.next = add nuw i64 %index, %1
+  %17 = icmp eq i64 %index.next, %n.vec
+  br i1 %17, label %middle.block, label %vector.body
+
+middle.block:                                     ; preds = %vector.body
+  %bin.rdx = add <vscale x 8 x i32> %16, %15  ; combine the two partial accumulators before the horizontal reduction
+  %18 = tail call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> %bin.rdx)
+  %cmp.n = icmp eq i64 %n.mod.vf, 0  ; no remainder: skip the scalar loop
+  br i1 %cmp.n, label %for.cond.cleanup.loopexit, label %for.body.preheader17
+
+for.body.preheader17:                             ; preds = %for.body.preheader, %middle.block
+  %indvars.iv.ph = phi i64 [ 0, %for.body.preheader ], [ %n.vec, %middle.block ]
+  %acc.010.ph = phi i32 [ 0, %for.body.preheader ], [ %18, %middle.block ]  ; vector sum seeds the scalar loop; CHECK expects a trunc of the i64 reduction here (DOT_TRUNC)
+  br label %for.body
+
+for.cond.cleanup.loopexit:                        ; preds = %for.body, %middle.block
+  %add.lcssa = phi i32 [ %18, %middle.block ], [ %add, %for.body ]  ; final i32 sum from either the vector or the scalar path
+  %19 = lshr i32 %add.lcssa, 16  ; bits [31:16] of the sum become the i16 result
+  %20 = trunc i32 %19 to i16
+  br label %for.cond.cleanup
+
+for.cond.cleanup:                                 ; preds = %for.cond.cleanup.loopexit, %entry
+  %acc.0.lcssa = phi i16 [ 0, %entry ], [ %20, %for.cond.cleanup.loopexit ]
+  ret i16 %acc.0.lcssa
+
+for.body:                                         ; preds = %for.body.preheader17, %for.body
+  %indvars.iv = phi i64 [ %indvars.iv.next, %for.body ], [ %indvars.iv.ph, %for.body.preheader17 ]
+  %acc.010 = phi i32 [ %add, %for.body ], [ %acc.010.ph, %for.body.preheader17 ]
+  %arrayidx = getelementptr inbounds i16, ptr %a, i64 %indvars.iv
+  %21 = load i16, ptr %arrayidx, align 2
+  %conv = sext i16 %21 to i32
+  %arrayidx2 = getelementptr inbounds i16, ptr %b, i64 %indvars.iv
+  %22 = load i16, ptr %arrayidx2, align 2
+  %conv3 = sext i16 %22 to i32
+  %mul = mul nsw i32 %conv3, %conv
+  %add = add nsw i32 %mul, %acc.010  ; scalar extend->mul->accumulate; CHECK expects it to stay scalar
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond.not = icmp eq i64 %indvars.iv.next, %wide.trip.count
+  br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
+}
+
+define i16 @sve_udot_loop_i16_to_i32(ptr readonly %a, ptr readonly %b, i32 %N) #0 {
+; CHECK-LABEL: define i16 @sve_udot_loop_i16_to_i32
+; CHECK-SAME: (ptr readonly [[A:%.*]], ptr readonly [[B:%.*]], i32 [[N:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP11_NOT:%.*]] = icmp eq i32 [[N]], 0
+; CHECK-NEXT:    br i1 [[CMP11_NOT]], label [[FOR_COND_CLEANUP:%.*]], label [[MIN_ITERS_CHECKED:%.*]]
+; CHECK:       min.iters.checked:
+; CHECK-NEXT:    [[WIDE_TRIP_COUNT:%.*]] = zext i32 [[N]] to i64
+; CHECK-NEXT:    [[PREDICATE_ENTRY:%.*]] = call <vscale x 8 x i1> @llvm.aarch64.sve.whilelo.nxv8i1.i64(i64 0, i64 [[WIDE_TRIP_COUNT]])
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[MIN_ITERS_CHECKED]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[PREDICATE:%.*]] = phi <vscale x 8 x i1> [ [[PREDICATE_ENTRY]], [[MIN_ITERS_CHECKED]] ], [ [[PREDICATE_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[DOT_ACCUMULATE:%.*]] = phi <vscale x 2 x i64> [ zeroinitializer, [[MIN_ITERS_CHECKED]] ], [ [[TMP2:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i16, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr [[TMP0]], i32 2, <vscale x 8 x i1> [[PREDICATE]], <vscale x 8 x i16> zeroinitializer)
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i16, ptr [[B]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD19:%.*]] = call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr [[TMP1]], i32 2, <vscale x 8 x i1> [[PREDICATE]], <vscale x 8 x i16> zeroinitializer)
+; CHECK-NEXT:    [[TMP2]] = call <vscale x 2 x i64> @llvm.aarch64.sve.udot.nxv2i64(<vscale x 2 x i64> [[DOT_ACCUMULATE]], <vscale x 8 x i16> [[WIDE_MASKED_LOAD19]], <vscale x 8 x i16> [[WIDE_MASKED_LOAD]])
+; CHECK-NEXT:    [[VS:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[VS_SCALED:%.*]] = shl i64 [[VS]], 3
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[VS_SCALED]]
+; CHECK-NEXT:    [[PREDICATE_NEXT]] = call <vscale x 8 x i1> @llvm.aarch64.sve.whilelo.nxv8i1.i64(i64 [[INDEX_NEXT]], i64 [[WIDE_TRIP_COUNT]])
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <vscale x 8 x i1> [[PREDICATE_NEXT]], i64 0
+; CHECK-NEXT:    br i1 [[TMP3]], label [[VECTOR_BODY]], label [[MIDDLE_BLOCK:%.*]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vector.reduce.add.nxv2i64(<vscale x 2 x i64> [[TMP2]])
+; CHECK-NEXT:    [[PHITMP201:%.*]] = lshr i64 [[TMP4]], 16
+; CHECK-NEXT:    [[PHITMP:%.*]] = trunc i64 [[PHITMP201]] to i16
+; CHECK-NEXT:    br label [[FOR_COND_CLEANUP]]
+; CHECK:       for.cond.cleanup:
+; CHECK-NEXT:    [[ACC_0_LCSSA:%.*]] = phi i16 [ 0, [[ENTRY:%.*]] ], [ [[PHITMP]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT:    ret i16 [[ACC_0_LCSSA]]
+;
+entry:  ; Computes sum(zext(a[i]) * zext(b[i])) in i32 over a whilelo-predicated loop and returns bits [31:16] of the sum as i16.
+  %cmp11 = icmp ugt i32 %N, 0  ; i.e. N != 0; CHECK shows it canonicalized to icmp eq with swapped successors
+  br i1 %cmp11, label %min.iters.checked, label %for.cond.cleanup
+
+min.iters.checked:                                ; preds = %entry
+  %wide.trip.count = zext i32 %N to i64
+  %wide.end.idx.splatinsert = insertelement <vscale x 8 x i64> undef, i64 %wide.trip.count, i32 0  ; unused, like the splat below; absent from the CHECK lines
+  %wide.end.idx.splat = shufflevector <vscale x 8 x i64> %wide.end.idx.splatinsert, <vscale x 8 x i64> undef, <vscale x 8 x i32> zeroinitializer
+  %predicate.entry = call <vscale x 8 x i1> @llvm.aarch64.sve.whilelo.nxv8i1.i64(i64 0, i64 %wide.trip.count)
+  br label %vector.body
+
+vector.body:                                      ; preds = %vector.body, %min.iters.checked
+  %index = phi i64 [ 0, %min.iters.checked ], [ %index.next, %vector.body ]
+  %predicate = phi <vscale x 8 x i1> [ %predicate.entry, %min.iters.checked ], [ %predicate.next, %vector.body ]
+  %vec.phi = phi <vscale x 8 x i32> [ zeroinitializer, %min.iters.checked ], [ %6, %vector.body ]  ; i32 accumulator; CHECK expects an nxv2i64 DOT_ACCUMULATE phi instead
+  %0 = getelementptr inbounds i16, ptr %a, i64 %index
+  %wide.masked.load = call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr %0, i32 2, <vscale x 8 x i1> %predicate, <vscale x 8 x i16> undef)  ; undef passthru; CHECK expects zeroinitializer after the rewrite
+  %1 = zext <vscale x 8 x i16> %wide.masked.load to <vscale x 8 x i32>
+  %2 = getelementptr inbounds i16, ptr %b, i64 %index
+  %wide.masked.load19 = call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr %2, i32 2, <vscale x 8 x i1> %predicate, <vscale x 8 x i16> undef)
+  %3 = zext <vscale x 8 x i16> %wide.masked.load19 to <vscale x 8 x i32>
+  %4 = mul nsw <vscale x 8 x i32> %3, %1  ; NOTE(review): zext'd i16 operands can exceed INT32_MAX (65535*65535 = 4294836225), making this nsw mul poison; nuw looks intended - confirm, then regenerate the CHECK lines
+  %5 = select <vscale x 8 x i1> %predicate, <vscale x 8 x i32> %4, <vscale x 8 x i32> zeroinitializer  ; zero the products of inactive lanes
+  %6 = add nsw <vscale x 8 x i32> %vec.phi, %5  ; end of the zext -> mul -> accumulate chain; CHECK expects one sve.udot call in its place
+  %vs = call i64 @llvm.vscale.i64()
+  %vs.scaled = mul i64 %vs, 8  ; 8 * vscale i16 elements per iteration
+  %index.next = add nuw i64 %index, %vs.scaled
+  %.splatinsert = insertelement <vscale x 8 x i64> undef, i64 %index.next, i32 0  ; unused, like %.splat below
+  %.splat = shufflevector <vscale x 8 x i64> %.splatinsert, <vscale x 8 x i64> undef, <vscale x 8 x i32> zeroinitializer
+  %predicate.next = call <vscale x 8 x i1> @llvm.aarch64.sve.whilelo.nxv8i1.i64(i64 %index.next, i64 %wide.trip.count)
+  %7 = extractelement <vscale x 8 x i1> %predicate.next, i64 0  ; loop while lane 0 of the next predicate is active
+  br i1 %7, label %vector.body, label %middle.block
+
+middle.block:                                     ; preds = %vector.body
+  %8 = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> %6)  ; horizontal sum of the i32 accumulator
+  %phitmp20 = lshr i32 %8, 16  ; bits [31:16] of the i32 sum become the i16 result
+  %phitmp = trunc i32 %phitmp20 to i16
+  br label %for.cond.cleanup
+
+for.cond.cleanup:                                 ; preds = %middle.block, %entry
+  %acc.0.lcssa = phi i16 [ 0, %entry ], [ %phitmp, %middle.block ]
+  ret i16 %acc.0.lcssa
+}
+
+define dso_local i16 @sve_udot_loop_i16_to_i32_interleavedx4_foldedtail(ptr readonly %a, ptr readonly %b, i32 %N) #0 {
+; CHECK-LABEL: define dso_local i16 @sve_udot_loop_i16_to_i32_interleavedx4_foldedtail
+; CHECK-SAME: (ptr readonly [[A:%.*]], ptr readonly [[B:%.*]], i32 [[N:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP9:%.*]] = icmp sgt i32 [[N]], 0
+; CHECK-NEXT:    br i1 [[CMP9]], label [[FOR_BODY_PREHEADER:%.*]], label [[FOR_COND_CLEANUP:%.*]]
+; CHECK:       for.body.preheader:
+; CHECK-NEXT:    [[WIDE_TRIP_COUNT:%.*]] = zext i32 [[N]] to i64
+; CHECK-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP1:%.*]] = shl nuw nsw i64 [[TMP0]], 3
+; CHECK-NEXT:    [[TMP2:%.*]] = shl nuw nsw i64 [[TMP0]], 4
+; CHECK-NEXT:    [[TMP3:%.*]] = mul nuw nsw i64 [[TMP0]], 24
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_ENTRY:%.*]] = tail call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 0, i64 [[WIDE_TRIP_COUNT]])
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_ENTRY16:%.*]] = tail call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 [[TMP2]], i64 [[WIDE_TRIP_COUNT]])
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_ENTRY15:%.*]] = tail call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 [[TMP1]], i64 [[WIDE_TRIP_COUNT]])
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_ENTRY17:%.*]] = tail call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 [[TMP3]], i64 [[WIDE_TRIP_COUNT]])
+; CHECK-NEXT:    [[TMP4:%.*]] = tail call i32 @llvm.vscale.i32()
+; CHECK-NEXT:    [[TMP5:%.*]] = shl nuw nsw i32 [[TMP4]], 3
+; CHECK-NEXT:    [[TMP6:%.*]] = zext i32 [[TMP5]] to i64
+; CHECK-NEXT:    [[TMP7:%.*]] = shl nuw nsw i32 [[TMP4]], 4
+; CHECK-NEXT:    [[TMP8:%.*]] = zext i32 [[TMP7]] to i64
+; CHECK-NEXT:    [[TMP9:%.*]] = mul nuw nsw i32 [[TMP4]], 24
+; CHECK-NEXT:    [[TMP10:%.*]] = zext i32 [[TMP9]] to i64
+; CHECK-NEXT:    [[TMP11:%.*]] = shl nuw nsw i64 [[TMP0]], 5
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[FOR_BODY_PREHEADER]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = phi <vscale x 8 x i1> [ [[ACTIVE_LANE_MASK_ENTRY]], [[FOR_BODY_PREHEADER]] ], [ [[ACTIVE_LANE_MASK_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK18:%.*]] = phi <vscale x 8 x i1> [ [[ACTIVE_LANE_MASK_ENTRY15]], [[FOR_BODY_PREHEADER]] ], [ [[ACTIVE_LANE_MASK_NEXT31:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK19:%.*]] = phi <vscale x 8 x i1> [ [[ACTIVE_LANE_MASK_ENTRY16]], [[FOR_BODY_PREHEADER]] ], [ [[ACTIVE_LANE_MASK_NEXT32:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK20:%.*]] = phi <vscale x 8 x i1> [ [[ACTIVE_LANE_MASK_ENTRY17]], [[FOR_BODY_PREHEADER]] ], [ [[ACTIVE_LANE_MASK_NEXT33:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[DOT_ACCUMULATE3:%.*]] = phi <vscale x 2 x i64> [ zeroinitializer, [[FOR_BODY_PREHEADER]] ], [ [[TMP20:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[DOT_ACCUMULATE2:%.*]] = phi <vscale x 2 x i64> [ zeroinitializer, [[FOR_BODY_PREHEADER]] ], [ [[TMP21:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[DOT_ACCUMULATE1:%.*]] = phi <vscale x 2 x i64> [ zeroinitializer, [[FOR_BODY_PREHEADER]] ], [ [[TMP22:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[DOT_ACCUMULATE:%.*]] = phi <vscale x 2 x i64> [ zeroinitializer, [[FOR_BODY_PREHEADER]] ], [ [[TMP23:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP12:%.*]] = getelementptr inbounds i16, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = tail call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr [[TMP12]], i32 2, <vscale x 8 x i1> [[ACTIVE_LANE_MASK]], <vscale x 8 x i16> zeroinitializer)
+; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr inbounds i16, ptr [[TMP12]], i64 [[TMP6]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD24:%.*]] = tail call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr nonnull [[TMP13]], i32 2, <vscale x 8 x i1> [[ACTIVE_LANE_MASK18]], <vscale x 8 x i16> zeroinitializer)
+; CHECK-NEXT:    [[TMP14:%.*]] = getelementptr inbounds i16, ptr [[TMP12]], i64 [[TMP8]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD25:%.*]] = tail call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr nonnull [[TMP14]], i32 2, <vscale x 8 x i1> [[ACTIVE_LANE_MASK19]], <vscale x 8 x i16> zeroinitializer)
+; CHECK-NEXT:    [[TMP15:%.*]] = getelementptr inbounds i16, ptr [[TMP12]], i64 [[TMP10]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD26:%.*]] = tail call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr nonnull [[TMP15]], i32 2, <vscale x 8 x i1> [[ACTIVE_LANE_MASK20]], <vscale x 8 x i16> zeroinitializer)
+; CHECK-NEXT:    [[TMP16:%.*]] = getelementptr inbounds i16, ptr [[B]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD27:%.*]] = tail call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr [[TMP16]], i32 2, <vscale x 8 x i1> [[ACTIVE_LANE_MASK]], <vscale x 8 x i16> zeroinitializer)
+; CHECK-NEXT:    [[TMP17:%.*]] = getelementptr inbounds i16, ptr [[TMP16]], i64 [[TMP6]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD28:%.*]] = tail call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr nonnull [[TMP17]], i32 2, <vscale x 8 x i1> [[ACTIVE_LANE_MASK18]], <vscale x 8 x i16> zeroinitializer)
+; CHECK-NEXT:    [[TMP18:%.*]] = getelementptr inbounds i16, ptr [[TMP16]], i64 [[TMP8]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD29:%.*]] = tail call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr nonnull [[TMP18]], i32 2, <vscale x 8 x i1> [[ACTIVE_LANE_MASK19]], <vscale x 8 x i16> zeroinitializer)
+; CHECK-NEXT:    [[TMP19:%.*]] = getelementptr inbounds i16, ptr [[TMP16]], i64 [[TMP10]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD30:%.*]] = tail call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr nonnull [[TMP19]], i32 2, <vscale x 8 x i1> [[ACTIVE_LANE_MASK20]], <vscale x 8 x i16> zeroinitializer)
+; CHECK-NEXT:    [[TMP20]] = call <vscale x 2 x i64> @llvm.aarch64.sve.udot.nxv2i64(<vscale x 2 x i64> [[DOT_ACCUMULATE3]], <vscale x 8 x i16> [[WIDE_MASKED_LOAD27]], <vscale x 8 x i16> [[WIDE_MASKED_LOAD]])
+; CHECK-NEXT:    [[TMP21]] = call <vscale x 2 x i64> @llvm.aarch64.sve.udot.nxv2i64(<vscale x 2 x i64> [[DOT_ACCUMULATE2]], <vscale x 8 x i16> [[WIDE_MASKED_LOAD28]], <vscale x 8 x i16> [[WIDE_MASKED_LOAD24]])
+; CHECK-NEXT:    [[TMP22]] = call <vscale x 2 x i64> @llvm.aarch64.sve.udot.nxv2i64(<vscale x 2 x i64> [[DOT_ACCUMULATE1]], <vscale x 8 x i16> [[WIDE_MASKED_LOAD29]], <vscale x 8 x i16> [[WIDE_MASKED_LOAD25]])
+; CHECK-NEXT:    [[TMP23]] = call <vscale x 2 x i64> @llvm.aarch64.sve.udot.nxv2i64(<vscale x 2 x i64> [[DOT_ACCUMULATE]], <vscale x 8 x i16> [[WIDE_MASKED_LOAD30]], <vscale x 8 x i16> [[WIDE_MASKED_LOAD26]])
+; CHECK-NEXT:    [[INDEX_NEXT]] = add i64 [[INDEX]], [[TMP11]]
+; CHECK-NEXT:    [[TMP24:%.*]] = add i64 [[INDEX_NEXT]], [[TMP1]]
+; CHECK-NEXT:    [[TMP25:%.*]] = add i64 [[INDEX_NEXT]], [[TMP2]]
+; CHECK-NEXT:    [[TMP26:%.*]] = add i64 [[INDEX_NEXT]], [[TMP3]]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_NEXT]] = tail call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 [[INDEX_NEXT]], i64 [[WIDE_TRIP_COUNT]])
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_NEXT31]] = tail call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 [[TMP24]], i64 [[WIDE_TRIP_COUNT]])
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_NEXT32]] = tail call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 [[TMP25]], i64 [[WIDE_TRIP_COUNT]])
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_NEXT33]] = tail call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 [[TMP26]], i64 [[WIDE_TRIP_COUNT]])
+; CHECK-NEXT:    [[TMP27:%.*]] = extractelement <vscale x 8 x i1> [[ACTIVE_LANE_MASK_NEXT]], i64 0
+; CHECK-NEXT:    br i1 [[TMP27]], label [[VECTOR_BODY]], label [[MIDDLE_BLOCK:%.*]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[TMP28:%.*]] = add <vscale x 2 x i64> [[TMP23]], [[TMP22]]
+; CHECK-NEXT:    [[TMP29:%.*]] = add <vscale x 2 x i64> [[TMP21]], [[TMP20]]
+; CHECK-NEXT:    [[TMP30:%.*]] = add <vscale x 2 x i64> [[TMP28]], [[TMP29]]
+; CHECK-NEXT:    [[TMP31:%.*]] = call i64 @llvm.vector.reduce.add.nxv2i64(<vscale x 2 x i64> [[TMP30]])
+; CHECK-NEXT:    [[TMP32:%.*]] = lshr i64 [[TMP31]], 16
+; CHECK-NEXT:    [[TMP33:%.*]] = trunc i64 [[TMP32]] to i16
+; CHECK-NEXT:    br label [[FOR_COND_CLEANUP]]
+; CHECK:       for.cond.cleanup:
+; CHECK-NEXT:    [[ACC_0_LCSSA:%.*]] = phi i16 [ 0, [[ENTRY:%.*]] ], [ [[TMP33]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT:    ret i16 [[ACC_0_LCSSA]]
+;
+entry:
+  %cmp9 = icmp sgt i32 %N, 0
+  br i1 %cmp9, label %for.body.preheader, label %for.cond.cleanup
+
+for.body.preheader:                               ; preds = %entry
+  %wide.trip.count = zext i32 %N to i64
+  %0 = tail call i64 @llvm.vscale.i64()
+  %1 = shl nuw nsw i64 %0, 3
+  %2 = shl nuw nsw i64 %0, 4
+  %3 = mul nuw nsw i64 %0, 24
+  %active.lane.mask.entry = tail call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 0, i64 %wide.trip.count)
+  %active.lane.mask.entry16 = tail call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %2, i64 %wide.trip.count)
+  %active.lane.mask.entry15 = tail call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %1, i64 %wide.trip.count)
+  %active.lane.mask.entry17 = tail call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %3, i64 %wide.trip.count)
+  %4 = tail call i32 @llvm.vscale.i32()
+  %5 = shl nuw nsw i32 %4, 3
+  %6 = zext i32 %5 to i64
+  %7 = shl nuw nsw i32 %4, 4
+  %8 = zext i32 %7 to i64
+  %9 = mul nuw nsw i32 %4, 24
+  %10 = zext i32 %9 to i64
+  %11 = shl nuw nsw i64 %0, 5
+  br label %vector.body
+
+vector.body:                                      ; preds = %vector.body, %for.body.preheader
+  %index = phi i64 [ 0, %for.body.preheader ], [ %index.next, %vector.body ]
+  %active.lane.mask = phi <vscale x 8 x i1> [ %active.lane.mask.entry, %for.body.preheader ], [ %active.lane.mask.next, %vector.body ]
+  %active.lane.mask18 = phi <vscale x 8 x i1> [ %active.lane.mask.entry15, %for.body.preheader ], [ %active.lane.mask.next31, %vector.body ]
+  %active.lane.mask19 = phi <vscale x 8 x i1> [ %active.lane.mask.entry16, %for.body.preheader ], [ %active.lane.mask.next32, %vector.body ]
+  %active.lane.mask20 = phi <vscale x 8 x i1> [ %active.lane.mask.entry17, %for.body.preheader ], [ %active.lane.mask.next33, %vector.body ]
+  %vec.phi = phi <vscale x 8 x i32> [ zeroinitializer, %for.body.preheader ], [ %33, %vector.body ]
+  %vec.phi21 = phi <vscale x 8 x i32> [ zeroinitializer, %for.body.preheader ], [ %35, %vector.body ]
+  %vec.phi22 = phi <vscale x 8 x i32> [ zeroinitializer, %for.body.preheader ], [ %37, %vector.body ]
+  %vec.phi23 = phi <vscale x 8 x i32> [ zeroinitializer, %for.body.preheader ], [ %39, %vector.body ]
+  %12 = getelementptr inbounds i16, ptr %a, i64 %index
+  %wide.masked.load = tail call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr %12, i32 2, <vscale x 8 x i1> %active.lane.mask, <vscale x 8 x i16> poison)
+  %13 = getelementptr inbounds i16, ptr %12, i64 %6
+  %wide.masked.load24 = tail call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr nonnull %13, i32 2, <vscale x 8 x i1> %active.lane.mask18, <vscale x 8 x i16> poison)
+  %14 = getelementptr inbounds i16, ptr %12, i64 %8
+  %wide.masked.load25 = tail call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr nonnull %14, i32 2, <vscale x 8 x i1> %active.lane.mask19, <vscale x 8 x i16> poison)
+  %15 = getelementptr inbounds i16, ptr %12, i64 %10
+  %wide.masked.load26 = tail call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr nonnull %15, i32 2, <vscale x 8 x i1> %active.lane.mask20, <vscale x 8 x i16> poison)
+  %16 = zext <vscale x 8 x i16> %wide.masked.load to <vscale x 8 x i32>
+  %17 = zext <vscale x 8 x i16> %wide.masked.load24 to <vscale x 8 x i32>
+  %18 = zext <vscale x 8 x i16> %wide.masked.load25 to <vscale x 8 x i32>
+  %19 = zext <vscale x 8 x i16> %wide.masked.load26 to <vscale x 8 x i32>
+  %20 = getelementptr inbounds i16, ptr %b, i64 %index
+  %wide.masked.load27 = tail call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr %20, i32 2, <vscale x 8 x i1> %active.lane.mask, <vscale x 8 x i16> poison)
+  %21 = getelementptr inbounds i16, ptr %20, i64 %6
+  %wide.masked.load28 = tail call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr nonnull %21, i32 2, <vscale x 8 x i1> %active.lane.mask18, <vscale x 8 x i16> poison)
+  %22 = getelementptr inbounds i16, ptr %20, i64 %8
+  %wide.masked.load29 = tail call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr nonnull %22, i32 2, <vscale x 8 x i1> %active.lane.mask19, <vscale x 8 x i16> poison)
+  %23 = getelementptr inbounds i16, ptr %20, i64 %10
+  %wide.masked.load30 = tail call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr nonnull %23, i32 2, <vscale x 8 x i1> %active.lane.mask20, <vscale x 8 x i16> poison)
+  %24 = zext <vscale x 8 x i16> %wide.masked.load27 to <vscale x 8 x i32>
+  %25 = zext <vscale x 8 x i16> %wide.masked.load28 to <vscale x 8 x i32>
+  %26 = zext <vscale x 8 x i16> %wide.masked.load29 to <vscale x 8 x i32>
+  %27 = zext <vscale x 8 x i16> %wide.masked.load30 to <vscale x 8 x i32>
+  %28 = mul nuw nsw <vscale x 8 x i32> %24, %16
+  %29 = mul nuw nsw <vscale x 8 x i32> %25, %17
+  %30 = mul nuw nsw <vscale x 8 x i32> %26, %18
+  %31 = mul nuw nsw <vscale x 8 x i32> %27, %19
+  %32 = select <vscale x 8 x i1> %active.lane.mask, <vscale x 8 x i32> %28, <vscale x 8 x i32> zeroinitializer
+  %33 = add <vscale x 8 x i32> %vec.phi, %32
+  %34 = select <vscale x 8 x i1> %active.lane.mask18, <vscale x 8 x i32> %29, <vscale x 8 x i32> zeroinitializer
+  %35 = add <vscale x 8 x i32> %vec.phi21, %34
+  %36 = select <vscale x 8 x i1> %active.lane.mask19, <vscale x 8 x i32> %30, <vscale x 8 x i32> zeroinitializer
+  %37 = add <vscale x 8 x i32> %vec.phi22, %36
+  %38 = select <vscale x 8 x i1> %active.lane.mask20, <vscale x 8 x i32> %31, <vscale x 8 x i32> zeroinitializer
+  %39 = add <vscale x 8 x i32> %vec.phi23, %38
+  %index.next = add i64 %index, %11
+  %40 = add i64 %index.next, %1
+  %41 = add i64 %index.next, %2
+  %42 = add i64 %index.next, %3
+  %active.lane.mask.next = tail call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %index.next, i64 %wide.trip.count)
+  %active.lane.mask.next31 = tail call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %40, i64 %wide.trip.count)
+  %active.lane.mask.next32 = tail call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %41, i64 %wide.trip.count)
+  %active.lane.mask.next33 = tail call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %42, i64 %wide.trip.count)
+  %43 = extractelement <vscale x 8 x i1> %active.lane.mask.next, i64 0
+  br i1 %43, label %vector.body, label %middle.block
+
+middle.block:                                     ; preds = %vector.body
+  %bin.rdx = add <vscale x 8 x i32> %35, %33
+  %bin.rdx34 = add <vscale x 8 x i32> %37, %bin.rdx
+  %bin.rdx35 = add <vscale x 8 x i32> %39, %bin.rdx34
+  %44 = tail call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> %bin.rdx35)
+  %45 = lshr i32 %44, 16
+  %46 = trunc i32 %45 to i16
+  br label %for.cond.cleanup
+
+for.cond.cleanup:                                 ; preds = %middle.block, %entry
+  %acc.0.lcssa = phi i16 [ 0, %entry ], [ %46, %middle.block ]
+  ret i16 %acc.0.lcssa
+}
+
+; Loop test: i8 x i8 products are sign-extended to i16, masked by a select on
+; the loop predicate, and accumulated into a <vscale x 16 x i16> phi. The CHECK
+; lines expect an SVE sdot accumulating into <vscale x 4 x i32> instead, with
+; the final lshr/trunc performed on the i32 reduction result.
+define i8 @sve_sdot_loop_i8_to_i16(ptr readonly %a, ptr readonly %b, i32 %N) #0 {
+; CHECK-LABEL: define i8 @sve_sdot_loop_i8_to_i16
+; CHECK-SAME: (ptr readonly [[A:%.*]], ptr readonly [[B:%.*]], i32 [[N:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP11:%.*]] = icmp sgt i32 [[N]], 0
+; CHECK-NEXT:    br i1 [[CMP11]], label [[MIN_ITERS_CHECKED:%.*]], label [[FOR_COND_CLEANUP:%.*]]
+; CHECK:       min.iters.checked:
+; CHECK-NEXT:    [[WIDE_TRIP_COUNT:%.*]] = zext i32 [[N]] to i64
+; CHECK-NEXT:    [[PREDICATE_ENTRY:%.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.whilelo.nxv16i1.i64(i64 0, i64 [[WIDE_TRIP_COUNT]])
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[MIN_ITERS_CHECKED]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[PREDICATE:%.*]] = phi <vscale x 16 x i1> [ [[PREDICATE_ENTRY]], [[MIN_ITERS_CHECKED]] ], [ [[PREDICATE_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[DOT_ACCUMULATE:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[MIN_ITERS_CHECKED]] ], [ [[TMP2:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <vscale x 16 x i8> @llvm.masked.load.nxv16i8.p0(ptr [[TMP0]], i32 1, <vscale x 16 x i1> [[PREDICATE]], <vscale x 16 x i8> zeroinitializer)
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD19:%.*]] = call <vscale x 16 x i8> @llvm.masked.load.nxv16i8.p0(ptr [[TMP1]], i32 1, <vscale x 16 x i1> [[PREDICATE]], <vscale x 16 x i8> zeroinitializer)
+; CHECK-NEXT:    [[TMP2]] = call <vscale x 4 x i32> @llvm.aarch64.sve.sdot.nxv4i32(<vscale x 4 x i32> [[DOT_ACCUMULATE]], <vscale x 16 x i8> [[WIDE_MASKED_LOAD19]], <vscale x 16 x i8> [[WIDE_MASKED_LOAD]])
+; CHECK-NEXT:    [[VS:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[VS_SCALED:%.*]] = shl i64 [[VS]], 4
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[VS_SCALED]]
+; CHECK-NEXT:    [[PREDICATE_NEXT]] = call <vscale x 16 x i1> @llvm.aarch64.sve.whilelo.nxv16i1.i64(i64 [[INDEX_NEXT]], i64 [[WIDE_TRIP_COUNT]])
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <vscale x 16 x i1> [[PREDICATE_NEXT]], i64 0
+; CHECK-NEXT:    br i1 [[TMP3]], label [[VECTOR_BODY]], label [[MIDDLE_BLOCK:%.*]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[TMP4:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[TMP2]])
+; CHECK-NEXT:    [[PHITMP201:%.*]] = lshr i32 [[TMP4]], 8
+; CHECK-NEXT:    [[PHITMP:%.*]] = trunc i32 [[PHITMP201]] to i8
+; CHECK-NEXT:    br label [[FOR_COND_CLEANUP]]
+; CHECK:       for.cond.cleanup:
+; CHECK-NEXT:    [[ACC_0_LCSSA:%.*]] = phi i8 [ 0, [[ENTRY:%.*]] ], [ [[PHITMP]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT:    ret i8 [[ACC_0_LCSSA]]
+;
+entry:
+  %cmp11 = icmp sgt i32 %N, 0
+  br i1 %cmp11, label %min.iters.checked, label %for.cond.cleanup
+
+min.iters.checked:                                ; preds = %entry
+  %wide.trip.count = zext i32 %N to i64
+  %predicate.entry = call <vscale x 16 x i1> @llvm.aarch64.sve.whilelo.nxv16i1.i64(i64 0, i64 %wide.trip.count)
+  br label %vector.body
+
+vector.body:                                      ; preds = %vector.body, %min.iters.checked
+  %index = phi i64 [ 0, %min.iters.checked ], [ %index.next, %vector.body ]
+  %predicate = phi <vscale x 16 x i1> [ %predicate.entry, %min.iters.checked ], [ %predicate.next, %vector.body ]
+  %vec.phi = phi <vscale x 16 x i16> [ zeroinitializer, %min.iters.checked ], [ %6, %vector.body ]
+  %0 = getelementptr inbounds i8, ptr %a, i64 %index
+  %wide.masked.load = call <vscale x 16 x i8> @llvm.masked.load.nxv16i8.p0(ptr %0, i32 1, <vscale x 16 x i1> %predicate, <vscale x 16 x i8> undef)
+  %1 = sext <vscale x 16 x i8> %wide.masked.load to <vscale x 16 x i16>
+  %2 = getelementptr inbounds i8, ptr %b, i64 %index
+  %wide.masked.load19 = call <vscale x 16 x i8> @llvm.masked.load.nxv16i8.p0(ptr %2, i32 1, <vscale x 16 x i1> %predicate, <vscale x 16 x i8> undef)
+  %3 = sext <vscale x 16 x i8> %wide.masked.load19 to <vscale x 16 x i16>
+  %4 = mul nsw <vscale x 16 x i16> %3, %1
+  %5 = select <vscale x 16 x i1> %predicate, <vscale x 16 x i16> %4, <vscale x 16 x i16> zeroinitializer
+  %6 = add nsw <vscale x 16 x i16> %vec.phi, %5
+  %vs = call i64 @llvm.vscale.i64()
+  %vs.scaled = mul i64 %vs, 16
+  %index.next = add nuw i64 %index, %vs.scaled
+  %predicate.next = call <vscale x 16 x i1> @llvm.aarch64.sve.whilelo.nxv16i1.i64(i64 %index.next, i64 %wide.trip.count)
+  %7 = extractelement <vscale x 16 x i1> %predicate.next, i64 0
+  br i1 %7, label %vector.body, label %middle.block
+
+middle.block:                                     ; preds = %vector.body
+  %8 = call i16 @llvm.vector.reduce.add.nxv16i16(<vscale x 16 x i16> %6)
+  %phitmp20 = lshr i16 %8, 8
+  %phitmp = trunc i16 %phitmp20 to i8
+  br label %for.cond.cleanup
+
+for.cond.cleanup:                                 ; preds = %middle.block, %entry
+  %acc.0.lcssa = phi i8 [ 0, %entry ], [ %phitmp, %middle.block ]
+  ret i8 %acc.0.lcssa
+}
+
+; Loop test: i8 x i8 products are zero-extended to i16 and added directly into
+; a <vscale x 16 x i16> phi, with no select on the loop predicate. The CHECK
+; lines expect an SVE udot accumulating into <vscale x 4 x i32>, with the masked
+; loads' undef passthru left unchanged.
+; NOTE(review): because there is no select, inactive lanes of the final
+; iteration add products of undef passthru values into the accumulator --
+; confirm this unmasked form is the intended test input.
+define i8 @sve_udot_loop_i8_to_i16(ptr readonly %a, ptr readonly %b, i32 %N) #0 {
+; CHECK-LABEL: define i8 @sve_udot_loop_i8_to_i16
+; CHECK-SAME: (ptr readonly [[A:%.*]], ptr readonly [[B:%.*]], i32 [[N:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP11:%.*]] = icmp sgt i32 [[N]], 0
+; CHECK-NEXT:    br i1 [[CMP11]], label [[MIN_ITERS_CHECKED:%.*]], label [[FOR_COND_CLEANUP:%.*]]
+; CHECK:       min.iters.checked:
+; CHECK-NEXT:    [[WIDE_TRIP_COUNT:%.*]] = zext i32 [[N]] to i64
+; CHECK-NEXT:    [[PREDICATE_ENTRY:%.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.whilelo.nxv16i1.i64(i64 0, i64 [[WIDE_TRIP_COUNT]])
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[MIN_ITERS_CHECKED]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[PREDICATE:%.*]] = phi <vscale x 16 x i1> [ [[PREDICATE_ENTRY]], [[MIN_ITERS_CHECKED]] ], [ [[PREDICATE_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[DOT_ACCUMULATE:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[MIN_ITERS_CHECKED]] ], [ [[TMP2:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <vscale x 16 x i8> @llvm.masked.load.nxv16i8.p0(ptr [[TMP0]], i32 1, <vscale x 16 x i1> [[PREDICATE]], <vscale x 16 x i8> undef)
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD19:%.*]] = call <vscale x 16 x i8> @llvm.masked.load.nxv16i8.p0(ptr [[TMP1]], i32 1, <vscale x 16 x i1> [[PREDICATE]], <vscale x 16 x i8> undef)
+; CHECK-NEXT:    [[TMP2]] = call <vscale x 4 x i32> @llvm.aarch64.sve.udot.nxv4i32(<vscale x 4 x i32> [[DOT_ACCUMULATE]], <vscale x 16 x i8> [[WIDE_MASKED_LOAD19]], <vscale x 16 x i8> [[WIDE_MASKED_LOAD]])
+; CHECK-NEXT:    [[VS:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[VS_SCALED:%.*]] = shl i64 [[VS]], 4
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[VS_SCALED]]
+; CHECK-NEXT:    [[PREDICATE_NEXT]] = call <vscale x 16 x i1> @llvm.aarch64.sve.whilelo.nxv16i1.i64(i64 [[INDEX_NEXT]], i64 [[WIDE_TRIP_COUNT]])
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <vscale x 16 x i1> [[PREDICATE_NEXT]], i64 0
+; CHECK-NEXT:    br i1 [[TMP3]], label [[VECTOR_BODY]], label [[MIDDLE_BLOCK:%.*]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[TMP4:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[TMP2]])
+; CHECK-NEXT:    [[PHITMP201:%.*]] = lshr i32 [[TMP4]], 8
+; CHECK-NEXT:    [[PHITMP:%.*]] = trunc i32 [[PHITMP201]] to i8
+; CHECK-NEXT:    br label [[FOR_COND_CLEANUP]]
+; CHECK:       for.cond.cleanup:
+; CHECK-NEXT:    [[ACC_0_LCSSA:%.*]] = phi i8 [ 0, [[ENTRY:%.*]] ], [ [[PHITMP]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT:    ret i8 [[ACC_0_LCSSA]]
+;
+entry:
+  %cmp11 = icmp sgt i32 %N, 0
+  br i1 %cmp11, label %min.iters.checked, label %for.cond.cleanup
+
+min.iters.checked:                                ; preds = %entry
+  %wide.trip.count = zext i32 %N to i64
+  %predicate.entry = call <vscale x 16 x i1> @llvm.aarch64.sve.whilelo.nxv16i1.i64(i64 0, i64 %wide.trip.count)
+  br label %vector.body
+
+vector.body:                                      ; preds = %vector.body, %min.iters.checked
+  %index = phi i64 [ 0, %min.iters.checked ], [ %index.next, %vector.body ]
+  %predicate = phi <vscale x 16 x i1> [ %predicate.entry, %min.iters.checked ], [ %predicate.next, %vector.body ]
+  %vec.phi = phi <vscale x 16 x i16> [ zeroinitializer, %min.iters.checked ], [ %5, %vector.body ]
+  %0 = getelementptr inbounds i8, ptr %a, i64 %index
+  %wide.masked.load = call <vscale x 16 x i8> @llvm.masked.load.nxv16i8.p0(ptr %0, i32 1, <vscale x 16 x i1> %predicate, <vscale x 16 x i8> undef)
+  %1 = zext <vscale x 16 x i8> %wide.masked.load to <vscale x 16 x i16>
+  %2 = getelementptr inbounds i8, ptr %b, i64 %index
+  %wide.masked.load19 = call <vscale x 16 x i8> @llvm.masked.load.nxv16i8.p0(ptr %2, i32 1, <vscale x 16 x i1> %predicate, <vscale x 16 x i8> undef)
+  %3 = zext <vscale x 16 x i8> %wide.masked.load19 to <vscale x 16 x i16>
+  %4 = mul nsw <vscale x 16 x i16> %3, %1
+  %5 = add nsw <vscale x 16 x i16> %vec.phi, %4
+  %vs = call i64 @llvm.vscale.i64()
+  %vs.scaled = mul i64 %vs, 16
+  %index.next = add nuw i64 %index, %vs.scaled
+  %predicate.next = call <vscale x 16 x i1> @llvm.aarch64.sve.whilelo.nxv16i1.i64(i64 %index.next, i64 %wide.trip.count)
+  %6 = extractelement <vscale x 16 x i1> %predicate.next, i64 0
+  br i1 %6, label %vector.body, label %middle.block
+
+middle.block:                                     ; preds = %vector.body
+  %7 = call i16 @llvm.vector.reduce.add.nxv16i16(<vscale x 16 x i16> %5)
+  %phitmp20 = lshr i16 %7, 8
+  %phitmp = trunc i16 %phitmp20 to i8
+  br label %for.cond.cleanup
+
+for.cond.cleanup:                                 ; preds = %middle.block, %entry
+  %acc.0.lcssa = phi i8 [ 0, %entry ], [ %phitmp, %middle.block ]
+  ret i8 %acc.0.lcssa
+}
+
+; Straight-line test: both i16 operands are sign-extended to i64, multiplied,
+; and the <vscale x 8 x i64> products are summed by a single vector.reduce.add.
+; The CHECK lines expect an SVE sdot into a zero <vscale x 2 x i64> accumulator,
+; followed by a reduce.add of that result.
+define i64 @sve_sdot_i16_to_i64(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) #0 {
+; CHECK-LABEL: define i64 @sve_sdot_i16_to_i64
+; CHECK-SAME: (<vscale x 8 x i16> [[A:%.*]], <vscale x 8 x i16> [[B:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = call <vscale x 2 x i64> @llvm.aarch64.sve.sdot.nxv2i64(<vscale x 2 x i64> zeroinitializer, <vscale x 8 x i16> [[A]], <vscale x 8 x i16> [[B]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i64 @llvm.vector.reduce.add.nxv2i64(<vscale x 2 x i64> [[TMP0]])
+; CHECK-NEXT:    ret i64 [[TMP1]]
+;
+entry:
+  %exta = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+  %extb = sext <vscale x 8 x i16> %b to <vscale x 8 x i64>
+  %mul = mul nsw <vscale x 8 x i64> %exta, %extb
+  %acc = call i64 @llvm.vector.reduce.add.nxv8i64(<vscale x 8 x i64> %mul)
+  ret i64 %acc
+}
+
+; Straight-line test: both i8 operands are zero-extended to i32, multiplied,
+; and the <vscale x 16 x i32> products are summed by a single vector.reduce.add.
+; The CHECK lines expect an SVE udot into a zero <vscale x 4 x i32> accumulator,
+; followed by a reduce.add of that result.
+define i32 @sve_udot_i8_to_i32(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) #0 {
+; CHECK-LABEL: define i32 @sve_udot_i8_to_i32
+; CHECK-SAME: (<vscale x 16 x i8> [[A:%.*]], <vscale x 16 x i8> [[B:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = call <vscale x 4 x i32> @llvm.aarch64.sve.udot.nxv4i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i8> [[A]], <vscale x 16 x i8> [[B]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[TMP0]])
+; CHECK-NEXT:    ret i32 [[TMP1]]
+;
+entry:
+  %exta = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %extb = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+  %mul = mul nsw <vscale x 16 x i32> %exta, %extb
+  %acc = call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> %mul)
+  ret i32 %acc
+}
+
+; Intrinsic declarations referenced by the test inputs. The SVE sdot/udot
+; intrinsics named in the CHECK lines do not appear in this list.
+declare <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr, i32, <vscale x 8 x i1>, <vscale x 8 x i16>)
+declare <vscale x 16 x i8> @llvm.masked.load.nxv16i8.p0(ptr, i32, <vscale x 16 x i1>, <vscale x 16 x i8>)
+declare i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32>)
+declare i16 @llvm.vector.reduce.add.nxv16i16(<vscale x 16 x i16>)
+declare i64 @llvm.vector.reduce.add.nxv8i64(<vscale x 8 x i64>)
+declare i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32>)
+declare i64 @llvm.vscale.i64()
+declare i32 @llvm.vscale.i32()
+declare <vscale x 8 x i1> @llvm.aarch64.sve.whilelo.nxv8i1.i64(i64, i64)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.whilelo.nxv16i1.i64(i64, i64)
+declare <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64, i64)
+
+; Attribute group #0 is attached to the test functions and enables the +sve
+; target feature.
+attributes #0 = { "target-features"="+sve" }



More information about the llvm-commits mailing list