[llvm] [LV][SVE] Recognize potential DOT sequences and use a wider VF (PR #69587)

Graham Hunter via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 19 03:47:18 PDT 2023


https://github.com/huntergr-arm created https://github.com/llvm/llvm-project/pull/69587

This patch extends the LoopVectorize cost model to identify when
an extend->multiply->accumulate chain is suitable for the AArch64
UDOT/SDOT instructions (SVE in particular), and to ignore the
extensions when determining desirable VFs.


>From 02cbd16fef027fcd4e505b8c988f41a194d9e46c Mon Sep 17 00:00:00 2001
From: Graham Hunter <graham.hunter at arm.com>
Date: Tue, 17 Oct 2023 16:28:52 +0100
Subject: [PATCH] [LV][SVE] Recognize potential DOT sequences and use a wider
 VF

This patch extends the LoopVectorize cost model to identify when
an extend->multiply->accumulate chain is suitable for the AArch64
UDOT/SDOT instructions (SVE in particular), and to ignore the
extensions when determining desirable VFs.
---
 .../llvm/Analysis/TargetTransformInfo.h       |   8 +
 .../llvm/Analysis/TargetTransformInfoImpl.h   |   2 +
 llvm/lib/Analysis/TargetTransformInfo.cpp     |   4 +
 .../AArch64/AArch64TargetTransformInfo.h      |  11 +
 .../Transforms/Vectorize/LoopVectorize.cpp    |  38 ++
 .../AArch64/maximize-bandwidth-for-dot.ll     | 485 ++++++++++++++++++
 6 files changed, 548 insertions(+)
 create mode 100644 llvm/test/Transforms/LoopVectorize/AArch64/maximize-bandwidth-for-dot.ll

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 5234ef8788d9e96..b11c325f31c5ccf 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -772,6 +772,10 @@ class TargetTransformInfo {
   /// Return true if the target supports masked expand load.
   bool isLegalMaskedExpandLoad(Type *DataType) const;
 
+  /// Return true if \p DataType values extended to \p ExtType, multiplied and
+  /// accumulated, map onto the target's dot-product instructions.
+  bool isLegalDotProd(Type *DataType, Type *ExtType) const;
+
   /// Return true if this is an alternating opcode pattern that can be lowered
   /// to a single instruction on the target. In X86 this is for the addsub
   /// instruction which corrsponds to a Shuffle + Fadd + FSub pattern in IR.
@@ -1787,6 +1791,7 @@ class TargetTransformInfo::Concept {
                                            Align Alignment) = 0;
   virtual bool isLegalMaskedCompressStore(Type *DataType) = 0;
   virtual bool isLegalMaskedExpandLoad(Type *DataType) = 0;
+  virtual bool isLegalDotProd(Type *DataType, Type *ExtType) = 0;
   virtual bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0,
                                unsigned Opcode1,
                                const SmallBitVector &OpcodeMask) const = 0;
@@ -2267,6 +2272,9 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   bool isLegalMaskedExpandLoad(Type *DataType) override {
     return Impl.isLegalMaskedExpandLoad(DataType);
   }
+  bool isLegalDotProd(Type *DataType, Type *ExtType) override {
+    return Impl.isLegalDotProd(DataType, ExtType);
+  }
   bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0, unsigned Opcode1,
                        const SmallBitVector &OpcodeMask) const override {
     return Impl.isLegalAltInstr(VecTy, Opcode0, Opcode1, OpcodeMask);
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index c1ff314ae51c98b..01f5af17a6f4814 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -302,6 +302,8 @@ class TargetTransformInfoImplBase {
 
   bool isLegalMaskedExpandLoad(Type *DataType) const { return false; }
 
+  bool isLegalDotProd(Type *DataType, Type *ExtType) const { return false; }
+
   bool enableOrderedReductions() const { return false; }
 
   bool hasDivRemOp(Type *DataType, bool IsSigned) const { return false; }
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index aad14f21d114619..fbbf8c3f5e34217 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -492,6 +492,10 @@ bool TargetTransformInfo::isLegalMaskedExpandLoad(Type *DataType) const {
   return TTIImpl->isLegalMaskedExpandLoad(DataType);
 }
 
+bool TargetTransformInfo::isLegalDotProd(Type *DataType, Type *ExtType) const {
+  return TTIImpl->isLegalDotProd(DataType, ExtType);
+}
+
 bool TargetTransformInfo::enableOrderedReductions() const {
   return TTIImpl->enableOrderedReductions();
 }
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index a6baade412c77d2..6be8f2867ec1a7f 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -151,6 +151,17 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
 
   unsigned getMaxInterleaveFactor(ElementCount VF);
 
+  // TODO: NEON could support this too via FEAT_DotProd (ST->hasDotProd()).
+  // Inputs must be i8 or i16 and the extended type no wider than the DOT
+  // accumulator element (i32 for i8, i64 for i16), so no data is lost.
+  bool isLegalDotProd(Type *DataType, Type *ExtType) const {
+    if (!ST->hasSVE() || !ExtType->isIntegerTy())
+      return false;
+    unsigned ExtBits = ExtType->getScalarSizeInBits();
+    return (DataType->isIntegerTy(8) && ExtBits <= 32) ||
+           (DataType->isIntegerTy(16) && ExtBits <= 64);
+  }
+
   bool prefersVectorizedAddressing() const;
 
   InstructionCost getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index aa435b0d47aa599..3b585cd221eda42 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -80,6 +80,7 @@
 #include "llvm/Analysis/CodeMetrics.h"
 #include "llvm/Analysis/DemandedBits.h"
 #include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Analysis/IVDescriptors.h"
 #include "llvm/Analysis/LoopAccessAnalysis.h"
 #include "llvm/Analysis/LoopAnalysisManager.h"
 #include "llvm/Analysis/LoopInfo.h"
@@ -1921,6 +1922,9 @@ class LoopVectorizationCostModel {
 
   /// All element types found in the loop.
   SmallPtrSet<Type *, 16> ElementTypesInLoop;
+
+  /// Extend instructions feeding a matched dot-product chain; costed as free.
+  SmallPtrSet<Value *, 2> DotExtends;
 };
 } // end namespace llvm
 
@@ -5580,6 +5584,7 @@ LoopVectorizationCostModel::getSmallestAndWidestTypes() {
 }
 
 void LoopVectorizationCostModel::collectElementTypesForWidening() {
+  using namespace llvm::PatternMatch;
   ElementTypesInLoop.clear();
   // For each block.
   for (BasicBlock *BB : TheLoop->blocks()) {
@@ -5607,6 +5612,34 @@ void LoopVectorizationCostModel::collectElementTypesForWidening() {
                                       RdxDesc.getRecurrenceType(),
                                       TargetTransformInfo::ReductionFlags()))
           continue;
+        // Look for a dot-product chain feeding this add reduction:
+        //   Sum = add(PN, mul(ext(A), ext(B)))
+        // If it matches, the extends are free and the recurrence type is not
+        // recorded, so it does not limit the VF. Otherwise fall through.
+        Instruction *Sum = RdxDesc.getLoopExitInstr();
+        if (RdxDesc.getRecurrenceKind() == RecurKind::Add && Sum &&
+            PN->hasOneUse() && *PN->user_begin() == Sum && Sum->hasNUses(2)) {
+          Value *Step = (Sum->getOperand(0) == PN) ? Sum->getOperand(1)
+                                                   : Sum->getOperand(0);
+          Value *ValA = nullptr, *ValB = nullptr;
+          if (match(Step,
+                    m_OneUse(m_Mul(m_ZExtOrSExt(m_OneUse(m_Value(ValA))),
+                                   m_ZExtOrSExt(m_OneUse(m_Value(ValB)))))) &&
+              (ValA->getType() == ValB->getType()) &&
+              TTI.isLegalDotProd(ValA->getType(), Step->getType())) {
+            auto *Mul = cast<Instruction>(Step);
+            Value *ExtA = Mul->getOperand(0), *ExtB = Mul->getOperand(1);
+
+            // UDOT/SDOT need both inputs extended the same way (no mixed
+            // signedness), and the extends must only feed the multiply.
+            if (Operator::getOpcode(ExtA) == Operator::getOpcode(ExtB) &&
+                ExtA->hasOneUser() && ExtB->hasOneUser()) {
+              DotExtends.insert(ExtA);
+              DotExtends.insert(ExtB);
+              continue;
+            }
+          }
+        }
         T = RdxDesc.getRecurrenceType();
       }
 
@@ -7351,6 +7384,11 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF,
         CCH = ComputeCCH(Load);
     }
 
+    // Extends recorded in DotExtends are expected to be folded into the
+    // target's DOT instruction, so they are costed as free here.
+    if (DotExtends.contains(I))
+      return 0;
+
     // We optimize the truncation of induction variables having constant
     // integer steps. The cost of these truncations is the same as the scalar
     // operation.
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/maximize-bandwidth-for-dot.ll b/llvm/test/Transforms/LoopVectorize/AArch64/maximize-bandwidth-for-dot.ll
new file mode 100644
index 000000000000000..2014bb18b11b104
--- /dev/null
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/maximize-bandwidth-for-dot.ll
@@ -0,0 +1,485 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2
+; RUN: opt -passes=loop-vectorize,simplifycfg,instcombine -force-vector-interleave=1 -prefer-predicate-over-epilogue=predicate-dont-vectorize -S < %s | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+;; For SVE, we want to make sure that we 'maximize bandwidth' during loop
+;; vectorization of IR patterns that roughly match an SDOT or UDOT instruction.
+;; Normally, <vscale x 8 x i32> wouldn't be considered since it takes up two
+;; registers, but since the *DOT instructions transform a given number of
+;; narrower input values into a smaller number of wider accumulation values, we
+;; won't actually use any additional registers for this case.
+;;
+;; This file only tests that the loop vectorizer prepares the IR for the
+;; AArch64DotProdMatcher pass. To do so, it needs to identify extends
+;; that will be folded away by the DOT instructions. For the first
+;; example below, the vectorized loop will use <vscale x 8 x i16> extended to
+;; <vscale x 8 x i32>. Normally the loop vectorizer would pick a VF of 4 since
+;; i32 is the widest type, but since the extends will be folded away we want to
+;; pick a VF of 8 to maximize the number of i16s processed per iteration.
+;;
+;; The backend pass will then match this and plant a DOT intrinsic with
+;; 2 <vscale x 8 x i16>s as input and one <vscale x 2 x i64> as output.
+;;
+;; If the extend would exceed the capacity of the DOT instruction (basically
+;; if i8s were extended to i64s), then we can't perform the second part of
+;; the transformation. We then wouldn't want to perform the first part either.
+;; We also want to stop the transform if there was another use of one of the
+;; values in the chain that would be folded into the DOT instruction, since
+;; the intermediate values would never exist in a register for reuse.
+
+define i16 @sdot_xform_example(ptr readonly %a, ptr readonly %b, i64 %N) #0 {
+; CHECK-LABEL: define i16 @sdot_xform_example
+; CHECK-SAME: (ptr readonly [[A:%.*]], ptr readonly [[B:%.*]], i64 [[N:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i64 [[TMP0]], 3
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.usub.sat.i64(i64 [[N]], i64 [[TMP1]])
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_ENTRY:%.*]] = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 0, i64 [[N]])
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = phi <vscale x 8 x i1> [ [[ACTIVE_LANE_MASK_ENTRY]], [[ENTRY]] ], [ [[ACTIVE_LANE_MASK_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 8 x i32> [ zeroinitializer, [[ENTRY]] ], [ [[TMP9:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i16, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr [[TMP3]], i32 2, <vscale x 8 x i1> [[ACTIVE_LANE_MASK]], <vscale x 8 x i16> poison)
+; CHECK-NEXT:    [[TMP4:%.*]] = sext <vscale x 8 x i16> [[WIDE_MASKED_LOAD]] to <vscale x 8 x i32>
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds i16, ptr [[B]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr [[TMP5]], i32 2, <vscale x 8 x i1> [[ACTIVE_LANE_MASK]], <vscale x 8 x i16> poison)
+; CHECK-NEXT:    [[TMP6:%.*]] = sext <vscale x 8 x i16> [[WIDE_MASKED_LOAD1]] to <vscale x 8 x i32>
+; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <vscale x 8 x i32> [[TMP6]], [[TMP4]]
+; CHECK-NEXT:    [[TMP8:%.*]] = select <vscale x 8 x i1> [[ACTIVE_LANE_MASK]], <vscale x 8 x i32> [[TMP7]], <vscale x 8 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP9]] = add <vscale x 8 x i32> [[VEC_PHI]], [[TMP8]]
+; CHECK-NEXT:    [[TMP10:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP11:%.*]] = shl i64 [[TMP10]], 3
+; CHECK-NEXT:    [[INDEX_NEXT]] = add i64 [[INDEX]], [[TMP11]]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_NEXT]] = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 [[INDEX]], i64 [[TMP2]])
+; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <vscale x 8 x i1> [[ACTIVE_LANE_MASK_NEXT]], i64 0
+; CHECK-NEXT:    br i1 [[TMP12]], label [[VECTOR_BODY]], label [[MIDDLE_BLOCK:%.*]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[TMP13:%.*]] = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> [[TMP9]])
+; CHECK-NEXT:    [[PHITMP:%.*]] = lshr i32 [[TMP13]], 16
+; CHECK-NEXT:    [[PHITMP14:%.*]] = trunc i32 [[PHITMP]] to i16
+; CHECK-NEXT:    ret i16 [[PHITMP14]]
+;
+
+entry:
+  br label %for.body
+
+for.body:                                         ; preds = %for.body, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %acc.012 = phi i32 [ 0, %entry ], [ %add, %for.body ]
+  %arrayidx = getelementptr inbounds i16, ptr %a, i64 %indvars.iv
+  %0 = load i16, ptr %arrayidx, align 2
+  %conv = sext i16 %0 to i32
+  %arrayidx2 = getelementptr inbounds i16, ptr %b, i64 %indvars.iv
+  %1 = load i16, ptr %arrayidx2, align 2
+  %conv3 = sext i16 %1 to i32
+  %mul = mul nsw i32 %conv3, %conv
+  %add = add nsw i32 %mul, %acc.012
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond = icmp eq i64 %indvars.iv.next, %N
+  br i1 %exitcond, label %for.cond.cleanup.loopexit, label %for.body
+
+for.cond.cleanup.loopexit:                        ; preds = %for.body
+  %phitmp = lshr i32 %add, 16
+  %phitmp14 = trunc i32 %phitmp to i16
+  br label %for.cond.cleanup
+
+for.cond.cleanup:                                 ; preds = %for.cond.cleanup.loopexit
+  ret i16 %phitmp14
+}
+
+;; Similar to the above check, but for a zext instead of a sext.
+
+define i16 @udot_xform_example(ptr readonly %a, ptr readonly %b, i64 %N) #0 {
+; CHECK-LABEL: define i16 @udot_xform_example
+; CHECK-SAME: (ptr readonly [[A:%.*]], ptr readonly [[B:%.*]], i64 [[N:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i64 [[TMP0]], 3
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.usub.sat.i64(i64 [[N]], i64 [[TMP1]])
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_ENTRY:%.*]] = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 0, i64 [[N]])
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = phi <vscale x 8 x i1> [ [[ACTIVE_LANE_MASK_ENTRY]], [[ENTRY]] ], [ [[ACTIVE_LANE_MASK_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 8 x i32> [ zeroinitializer, [[ENTRY]] ], [ [[TMP9:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i16, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr [[TMP3]], i32 2, <vscale x 8 x i1> [[ACTIVE_LANE_MASK]], <vscale x 8 x i16> poison)
+; CHECK-NEXT:    [[TMP4:%.*]] = zext <vscale x 8 x i16> [[WIDE_MASKED_LOAD]] to <vscale x 8 x i32>
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds i16, ptr [[B]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr [[TMP5]], i32 2, <vscale x 8 x i1> [[ACTIVE_LANE_MASK]], <vscale x 8 x i16> poison)
+; CHECK-NEXT:    [[TMP6:%.*]] = zext <vscale x 8 x i16> [[WIDE_MASKED_LOAD1]] to <vscale x 8 x i32>
+; CHECK-NEXT:    [[TMP7:%.*]] = mul nuw nsw <vscale x 8 x i32> [[TMP6]], [[TMP4]]
+; CHECK-NEXT:    [[TMP8:%.*]] = select <vscale x 8 x i1> [[ACTIVE_LANE_MASK]], <vscale x 8 x i32> [[TMP7]], <vscale x 8 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP9]] = add <vscale x 8 x i32> [[VEC_PHI]], [[TMP8]]
+; CHECK-NEXT:    [[TMP10:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP11:%.*]] = shl i64 [[TMP10]], 3
+; CHECK-NEXT:    [[INDEX_NEXT]] = add i64 [[INDEX]], [[TMP11]]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_NEXT]] = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 [[INDEX]], i64 [[TMP2]])
+; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <vscale x 8 x i1> [[ACTIVE_LANE_MASK_NEXT]], i64 0
+; CHECK-NEXT:    br i1 [[TMP12]], label [[VECTOR_BODY]], label [[MIDDLE_BLOCK:%.*]], !llvm.loop [[LOOP3:![0-9]+]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[TMP13:%.*]] = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> [[TMP9]])
+; CHECK-NEXT:    [[PHITMP:%.*]] = lshr i32 [[TMP13]], 16
+; CHECK-NEXT:    [[PHITMP14:%.*]] = trunc i32 [[PHITMP]] to i16
+; CHECK-NEXT:    ret i16 [[PHITMP14]]
+;
+
+entry:
+  br label %for.body
+
+for.body:                                         ; preds = %for.body, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %acc.012 = phi i32 [ 0, %entry ], [ %add, %for.body ]
+  %arrayidx = getelementptr inbounds i16, ptr %a, i64 %indvars.iv
+  %0 = load i16, ptr %arrayidx, align 2
+  %conv = zext i16 %0 to i32
+  %arrayidx2 = getelementptr inbounds i16, ptr %b, i64 %indvars.iv
+  %1 = load i16, ptr %arrayidx2, align 2
+  %conv3 = zext i16 %1 to i32
+  %mul = mul nsw i32 %conv3, %conv
+  %add = add nsw i32 %mul, %acc.012
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond = icmp eq i64 %indvars.iv.next, %N
+  br i1 %exitcond, label %for.cond.cleanup.loopexit, label %for.body
+
+for.cond.cleanup.loopexit:                        ; preds = %for.body
+  %phitmp = lshr i32 %add, 16
+  %phitmp14 = trunc i32 %phitmp to i16
+  br label %for.cond.cleanup
+
+for.cond.cleanup:                                 ; preds = %for.cond.cleanup.loopexit
+  ret i16 %phitmp14
+}
+
+;; In this case we don't want to use the maximum bandwidth since the accumulator
+;; type (i64) is wider than it would be in the sdot instruction for i8 inputs
+;; (i32).
+
+define i8 @sdot_xform_too_wide(ptr readonly %a, ptr readonly %b, i64 %N) #0 {
+; CHECK-LABEL: define i8 @sdot_xform_too_wide
+; CHECK-SAME: (ptr readonly [[A:%.*]], ptr readonly [[B:%.*]], i64 [[N:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i64 [[TMP0]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.usub.sat.i64(i64 [[N]], i64 [[TMP1]])
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_ENTRY:%.*]] = call <vscale x 2 x i1> @llvm.get.active.lane.mask.nxv2i1.i64(i64 0, i64 [[N]])
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = phi <vscale x 2 x i1> [ [[ACTIVE_LANE_MASK_ENTRY]], [[ENTRY]] ], [ [[ACTIVE_LANE_MASK_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 2 x i64> [ zeroinitializer, [[ENTRY]] ], [ [[TMP9:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <vscale x 2 x i8> @llvm.masked.load.nxv2i8.p0(ptr [[TMP3]], i32 2, <vscale x 2 x i1> [[ACTIVE_LANE_MASK]], <vscale x 2 x i8> poison)
+; CHECK-NEXT:    [[TMP4:%.*]] = sext <vscale x 2 x i8> [[WIDE_MASKED_LOAD]] to <vscale x 2 x i64>
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <vscale x 2 x i8> @llvm.masked.load.nxv2i8.p0(ptr [[TMP5]], i32 2, <vscale x 2 x i1> [[ACTIVE_LANE_MASK]], <vscale x 2 x i8> poison)
+; CHECK-NEXT:    [[TMP6:%.*]] = sext <vscale x 2 x i8> [[WIDE_MASKED_LOAD1]] to <vscale x 2 x i64>
+; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <vscale x 2 x i64> [[TMP6]], [[TMP4]]
+; CHECK-NEXT:    [[TMP8:%.*]] = select <vscale x 2 x i1> [[ACTIVE_LANE_MASK]], <vscale x 2 x i64> [[TMP7]], <vscale x 2 x i64> zeroinitializer
+; CHECK-NEXT:    [[TMP9]] = add <vscale x 2 x i64> [[VEC_PHI]], [[TMP8]]
+; CHECK-NEXT:    [[TMP10:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP11:%.*]] = shl i64 [[TMP10]], 1
+; CHECK-NEXT:    [[INDEX_NEXT]] = add i64 [[INDEX]], [[TMP11]]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_NEXT]] = call <vscale x 2 x i1> @llvm.get.active.lane.mask.nxv2i1.i64(i64 [[INDEX]], i64 [[TMP2]])
+; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <vscale x 2 x i1> [[ACTIVE_LANE_MASK_NEXT]], i64 0
+; CHECK-NEXT:    br i1 [[TMP12]], label [[VECTOR_BODY]], label [[MIDDLE_BLOCK:%.*]], !llvm.loop [[LOOP4:![0-9]+]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[TMP13:%.*]] = call i64 @llvm.vector.reduce.add.nxv2i64(<vscale x 2 x i64> [[TMP9]])
+; CHECK-NEXT:    [[PHITMP:%.*]] = lshr i64 [[TMP13]], 56
+; CHECK-NEXT:    [[PHITMP14:%.*]] = trunc i64 [[PHITMP]] to i8
+; CHECK-NEXT:    ret i8 [[PHITMP14]]
+;
+
+entry:
+  br label %for.body
+
+for.body:                                         ; preds = %for.body, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %acc.012 = phi i64 [ 0, %entry ], [ %add, %for.body ]
+  %arrayidx = getelementptr inbounds i8, ptr %a, i64 %indvars.iv
+  %0 = load i8, ptr %arrayidx, align 2
+  %conv = sext i8 %0 to i64
+  %arrayidx2 = getelementptr inbounds i8, ptr %b, i64 %indvars.iv
+  %1 = load i8, ptr %arrayidx2, align 2
+  %conv3 = sext i8 %1 to i64
+  %mul = mul nsw i64 %conv3, %conv
+  %add = add nsw i64 %mul, %acc.012
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond = icmp eq i64 %indvars.iv.next, %N
+  br i1 %exitcond, label %for.cond.cleanup.loopexit, label %for.body
+
+for.cond.cleanup.loopexit:                        ; preds = %for.body
+  %phitmp = lshr i64 %add, 56
+  %phitmp14 = trunc i64 %phitmp to i8
+  br label %for.cond.cleanup
+
+for.cond.cleanup:                                 ; preds = %for.cond.cleanup.loopexit
+  ret i8 %phitmp14
+}
+
+;; In this case we don't want to use the maximum bandwidth because one of the
+;; values is used elsewhere in the loop, and would need to be calculated anyway
+;; instead of just being part of the udot instruction.
+
+define i16 @udot_xform_extra_use(ptr readonly %a, ptr readonly %b, ptr noalias %c, i64 %N) #0 {
+; CHECK-LABEL: define i16 @udot_xform_extra_use
+; CHECK-SAME: (ptr readonly [[A:%.*]], ptr readonly [[B:%.*]], ptr noalias [[C:%.*]], i64 [[N:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i64 [[TMP0]], 2
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.usub.sat.i64(i64 [[N]], i64 [[TMP1]])
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_ENTRY:%.*]] = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 0, i64 [[N]])
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = phi <vscale x 4 x i1> [ [[ACTIVE_LANE_MASK_ENTRY]], [[ENTRY]] ], [ [[ACTIVE_LANE_MASK_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[ENTRY]] ], [ [[TMP10:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i16, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <vscale x 4 x i16> @llvm.masked.load.nxv4i16.p0(ptr [[TMP3]], i32 2, <vscale x 4 x i1> [[ACTIVE_LANE_MASK]], <vscale x 4 x i16> poison)
+; CHECK-NEXT:    [[TMP4:%.*]] = zext <vscale x 4 x i16> [[WIDE_MASKED_LOAD]] to <vscale x 4 x i32>
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds i16, ptr [[B]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <vscale x 4 x i16> @llvm.masked.load.nxv4i16.p0(ptr [[TMP5]], i32 2, <vscale x 4 x i1> [[ACTIVE_LANE_MASK]], <vscale x 4 x i16> poison)
+; CHECK-NEXT:    [[TMP6:%.*]] = zext <vscale x 4 x i16> [[WIDE_MASKED_LOAD1]] to <vscale x 4 x i32>
+; CHECK-NEXT:    [[TMP7:%.*]] = mul nuw nsw <vscale x 4 x i32> [[TMP6]], [[TMP4]]
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds i32, ptr [[C]], i64 [[INDEX]]
+; CHECK-NEXT:    call void @llvm.masked.store.nxv4i32.p0(<vscale x 4 x i32> [[TMP7]], ptr [[TMP8]], i32 4, <vscale x 4 x i1> [[ACTIVE_LANE_MASK]])
+; CHECK-NEXT:    [[TMP9:%.*]] = select <vscale x 4 x i1> [[ACTIVE_LANE_MASK]], <vscale x 4 x i32> [[TMP7]], <vscale x 4 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP10]] = add <vscale x 4 x i32> [[VEC_PHI]], [[TMP9]]
+; CHECK-NEXT:    [[TMP11:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP12:%.*]] = shl i64 [[TMP11]], 2
+; CHECK-NEXT:    [[INDEX_NEXT]] = add i64 [[INDEX]], [[TMP12]]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_NEXT]] = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 [[INDEX]], i64 [[TMP2]])
+; CHECK-NEXT:    [[TMP13:%.*]] = extractelement <vscale x 4 x i1> [[ACTIVE_LANE_MASK_NEXT]], i64 0
+; CHECK-NEXT:    br i1 [[TMP13]], label [[VECTOR_BODY]], label [[MIDDLE_BLOCK:%.*]], !llvm.loop [[LOOP5:![0-9]+]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[TMP14:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[TMP10]])
+; CHECK-NEXT:    [[PHITMP:%.*]] = lshr i32 [[TMP14]], 16
+; CHECK-NEXT:    [[PHITMP14:%.*]] = trunc i32 [[PHITMP]] to i16
+; CHECK-NEXT:    ret i16 [[PHITMP14]]
+;
+
+entry:
+  br label %for.body
+
+for.body:                                         ; preds = %for.body, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %acc.012 = phi i32 [ 0, %entry ], [ %add, %for.body ]
+  %arrayidx = getelementptr inbounds i16, ptr %a, i64 %indvars.iv
+  %0 = load i16, ptr %arrayidx, align 2
+  %conv = zext i16 %0 to i32
+  %arrayidx2 = getelementptr inbounds i16, ptr %b, i64 %indvars.iv
+  %1 = load i16, ptr %arrayidx2, align 2
+  %conv3 = zext i16 %1 to i32
+  %mul = mul nsw i32 %conv3, %conv
+  %outidx = getelementptr inbounds i32, ptr %c, i64 %indvars.iv
+  store i32 %mul, ptr %outidx, align 4
+  %add = add nsw i32 %mul, %acc.012
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond = icmp eq i64 %indvars.iv.next, %N
+  br i1 %exitcond, label %for.cond.cleanup.loopexit, label %for.body
+
+for.cond.cleanup.loopexit:                        ; preds = %for.body
+  %phitmp = lshr i32 %add, 16
+  %phitmp14 = trunc i32 %phitmp to i16
+  br label %for.cond.cleanup
+
+for.cond.cleanup:                                 ; preds = %for.cond.cleanup.loopexit
+  ret i16 %phitmp14
+}
+
+;; Similar to the first successful example, using i8 -> i16 instead
+
+define i8 @sdoti8toi16_xform_example(ptr readonly %a, ptr readonly %b, i64 %N) #0 {
+; CHECK-LABEL: define i8 @sdoti8toi16_xform_example
+; CHECK-SAME: (ptr readonly [[A:%.*]], ptr readonly [[B:%.*]], i64 [[N:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i64 [[TMP0]], 4
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.usub.sat.i64(i64 [[N]], i64 [[TMP1]])
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_ENTRY:%.*]] = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 0, i64 [[N]])
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = phi <vscale x 16 x i1> [ [[ACTIVE_LANE_MASK_ENTRY]], [[ENTRY]] ], [ [[ACTIVE_LANE_MASK_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 16 x i16> [ zeroinitializer, [[ENTRY]] ], [ [[TMP9:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <vscale x 16 x i8> @llvm.masked.load.nxv16i8.p0(ptr [[TMP3]], i32 2, <vscale x 16 x i1> [[ACTIVE_LANE_MASK]], <vscale x 16 x i8> poison)
+; CHECK-NEXT:    [[TMP4:%.*]] = sext <vscale x 16 x i8> [[WIDE_MASKED_LOAD]] to <vscale x 16 x i16>
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <vscale x 16 x i8> @llvm.masked.load.nxv16i8.p0(ptr [[TMP5]], i32 2, <vscale x 16 x i1> [[ACTIVE_LANE_MASK]], <vscale x 16 x i8> poison)
+; CHECK-NEXT:    [[TMP6:%.*]] = sext <vscale x 16 x i8> [[WIDE_MASKED_LOAD1]] to <vscale x 16 x i16>
+; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <vscale x 16 x i16> [[TMP6]], [[TMP4]]
+; CHECK-NEXT:    [[TMP8:%.*]] = select <vscale x 16 x i1> [[ACTIVE_LANE_MASK]], <vscale x 16 x i16> [[TMP7]], <vscale x 16 x i16> zeroinitializer
+; CHECK-NEXT:    [[TMP9]] = add <vscale x 16 x i16> [[VEC_PHI]], [[TMP8]]
+; CHECK-NEXT:    [[TMP10:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP11:%.*]] = shl i64 [[TMP10]], 4
+; CHECK-NEXT:    [[INDEX_NEXT]] = add i64 [[INDEX]], [[TMP11]]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_NEXT]] = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 [[INDEX]], i64 [[TMP2]])
+; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <vscale x 16 x i1> [[ACTIVE_LANE_MASK_NEXT]], i64 0
+; CHECK-NEXT:    br i1 [[TMP12]], label [[VECTOR_BODY]], label [[MIDDLE_BLOCK:%.*]], !llvm.loop [[LOOP6:![0-9]+]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[TMP13:%.*]] = call i16 @llvm.vector.reduce.add.nxv16i16(<vscale x 16 x i16> [[TMP9]])
+; CHECK-NEXT:    [[PHITMP:%.*]] = lshr i16 [[TMP13]], 8
+; CHECK-NEXT:    [[PHITMP14:%.*]] = trunc i16 [[PHITMP]] to i8
+; CHECK-NEXT:    ret i8 [[PHITMP14]]
+;
+
+entry:
+  br label %for.body
+
+for.body:                                         ; preds = %for.body, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %acc.012 = phi i16 [ 0, %entry ], [ %add, %for.body ]
+  %arrayidx = getelementptr inbounds i8, ptr %a, i64 %indvars.iv
+  %0 = load i8, ptr %arrayidx, align 2
+  %conv = sext i8 %0 to i16
+  %arrayidx2 = getelementptr inbounds i8, ptr %b, i64 %indvars.iv
+  %1 = load i8, ptr %arrayidx2, align 2
+  %conv3 = sext i8 %1 to i16
+  %mul = mul nsw i16 %conv3, %conv
+  %add = add nsw i16 %mul, %acc.012
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond = icmp eq i64 %indvars.iv.next, %N
+  br i1 %exitcond, label %for.cond.cleanup.loopexit, label %for.body
+
+for.cond.cleanup.loopexit:                        ; preds = %for.body
+  %phitmp = lshr i16 %add, 8
+  %phitmp14 = trunc i16 %phitmp to i8
+  br label %for.cond.cleanup
+
+for.cond.cleanup:                                 ; preds = %for.cond.cleanup.loopexit
+  ret i8 %phitmp14
+}
+
+;; Now with i8 -> i32: both i8 inputs are sign-extended, multiplied and accumulated in i32.
+
+define i8 @sdoti8toi32_xform_example(ptr readonly %a, ptr readonly %b, i64 %N) #0 {
+; CHECK-LABEL: define i8 @sdoti8toi32_xform_example
+; CHECK-SAME: (ptr readonly [[A:%.*]], ptr readonly [[B:%.*]], i64 [[N:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i64 [[TMP0]], 4
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.usub.sat.i64(i64 [[N]], i64 [[TMP1]])
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_ENTRY:%.*]] = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 0, i64 [[N]])
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = phi <vscale x 16 x i1> [ [[ACTIVE_LANE_MASK_ENTRY]], [[ENTRY]] ], [ [[ACTIVE_LANE_MASK_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 16 x i32> [ zeroinitializer, [[ENTRY]] ], [ [[TMP9:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <vscale x 16 x i8> @llvm.masked.load.nxv16i8.p0(ptr [[TMP3]], i32 1, <vscale x 16 x i1> [[ACTIVE_LANE_MASK]], <vscale x 16 x i8> poison)
+; CHECK-NEXT:    [[TMP4:%.*]] = sext <vscale x 16 x i8> [[WIDE_MASKED_LOAD]] to <vscale x 16 x i32>
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <vscale x 16 x i8> @llvm.masked.load.nxv16i8.p0(ptr [[TMP5]], i32 1, <vscale x 16 x i1> [[ACTIVE_LANE_MASK]], <vscale x 16 x i8> poison)
+; CHECK-NEXT:    [[TMP6:%.*]] = sext <vscale x 16 x i8> [[WIDE_MASKED_LOAD1]] to <vscale x 16 x i32>
+; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <vscale x 16 x i32> [[TMP6]], [[TMP4]]
+; CHECK-NEXT:    [[TMP8:%.*]] = select <vscale x 16 x i1> [[ACTIVE_LANE_MASK]], <vscale x 16 x i32> [[TMP7]], <vscale x 16 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP9]] = add <vscale x 16 x i32> [[VEC_PHI]], [[TMP8]]
+; CHECK-NEXT:    [[TMP10:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP11:%.*]] = shl i64 [[TMP10]], 4
+; CHECK-NEXT:    [[INDEX_NEXT]] = add i64 [[INDEX]], [[TMP11]]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_NEXT]] = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 [[INDEX]], i64 [[TMP2]])
+; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <vscale x 16 x i1> [[ACTIVE_LANE_MASK_NEXT]], i64 0
+; CHECK-NEXT:    br i1 [[TMP12]], label [[VECTOR_BODY]], label [[MIDDLE_BLOCK:%.*]], !llvm.loop [[LOOP7:![0-9]+]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[TMP13:%.*]] = call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> [[TMP9]])
+; CHECK-NEXT:    [[PHITMP:%.*]] = lshr i32 [[TMP13]], 8
+; CHECK-NEXT:    [[PHITMP14:%.*]] = trunc i32 [[PHITMP]] to i8
+; CHECK-NEXT:    ret i8 [[PHITMP14]]
+;
+
+entry:
+  br label %for.body
+
+for.body:                                         ; preds = %for.body, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %acc.012 = phi i32 [ 0, %entry ], [ %add, %for.body ]
+  %arrayidx = getelementptr inbounds i8, ptr %a, i64 %indvars.iv
+  %0 = load i8, ptr %arrayidx, align 1 ; consecutive i8 elements can only promise byte alignment
+  %conv = sext i8 %0 to i32
+  %arrayidx2 = getelementptr inbounds i8, ptr %b, i64 %indvars.iv
+  %1 = load i8, ptr %arrayidx2, align 1
+  %conv3 = sext i8 %1 to i32
+  %mul = mul nsw i32 %conv3, %conv
+  %add = add nsw i32 %mul, %acc.012
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond = icmp eq i64 %indvars.iv.next, %N
+  br i1 %exitcond, label %for.cond.cleanup.loopexit, label %for.body
+
+for.cond.cleanup.loopexit:                        ; preds = %for.body
+  %phitmp = lshr i32 %add, 8
+  %phitmp14 = trunc i32 %phitmp to i8
+  br label %for.cond.cleanup
+
+for.cond.cleanup:                                 ; preds = %for.cond.cleanup.loopexit
+  ret i8 %phitmp14
+}
+
+;; And finally i16 -> i64: both i16 inputs are sign-extended, multiplied and accumulated in i64.
+
+define i16 @sdoti16toi64_xform_example(ptr readonly %a, ptr readonly %b, i64 %N) #0 {
+; CHECK-LABEL: define i16 @sdoti16toi64_xform_example
+; CHECK-SAME: (ptr readonly [[A:%.*]], ptr readonly [[B:%.*]], i64 [[N:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i64 [[TMP0]], 3
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.usub.sat.i64(i64 [[N]], i64 [[TMP1]])
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_ENTRY:%.*]] = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 0, i64 [[N]])
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = phi <vscale x 8 x i1> [ [[ACTIVE_LANE_MASK_ENTRY]], [[ENTRY]] ], [ [[ACTIVE_LANE_MASK_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 8 x i64> [ zeroinitializer, [[ENTRY]] ], [ [[TMP9:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i16, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr [[TMP3]], i32 2, <vscale x 8 x i1> [[ACTIVE_LANE_MASK]], <vscale x 8 x i16> poison)
+; CHECK-NEXT:    [[TMP4:%.*]] = sext <vscale x 8 x i16> [[WIDE_MASKED_LOAD]] to <vscale x 8 x i64>
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds i16, ptr [[B]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr [[TMP5]], i32 2, <vscale x 8 x i1> [[ACTIVE_LANE_MASK]], <vscale x 8 x i16> poison)
+; CHECK-NEXT:    [[TMP6:%.*]] = sext <vscale x 8 x i16> [[WIDE_MASKED_LOAD1]] to <vscale x 8 x i64>
+; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <vscale x 8 x i64> [[TMP6]], [[TMP4]]
+; CHECK-NEXT:    [[TMP8:%.*]] = select <vscale x 8 x i1> [[ACTIVE_LANE_MASK]], <vscale x 8 x i64> [[TMP7]], <vscale x 8 x i64> zeroinitializer
+; CHECK-NEXT:    [[TMP9]] = add <vscale x 8 x i64> [[VEC_PHI]], [[TMP8]]
+; CHECK-NEXT:    [[TMP10:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP11:%.*]] = shl i64 [[TMP10]], 3
+; CHECK-NEXT:    [[INDEX_NEXT]] = add i64 [[INDEX]], [[TMP11]]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_NEXT]] = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 [[INDEX]], i64 [[TMP2]])
+; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <vscale x 8 x i1> [[ACTIVE_LANE_MASK_NEXT]], i64 0
+; CHECK-NEXT:    br i1 [[TMP12]], label [[VECTOR_BODY]], label [[MIDDLE_BLOCK:%.*]], !llvm.loop [[LOOP8:![0-9]+]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[TMP13:%.*]] = call i64 @llvm.vector.reduce.add.nxv8i64(<vscale x 8 x i64> [[TMP9]])
+; CHECK-NEXT:    [[PHITMP:%.*]] = lshr i64 [[TMP13]], 16
+; CHECK-NEXT:    [[PHITMP14:%.*]] = trunc i64 [[PHITMP]] to i16
+; CHECK-NEXT:    ret i16 [[PHITMP14]]
+;
+
+entry:
+  br label %for.body
+
+for.body:                                         ; preds = %for.body, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %acc.012 = phi i64 [ 0, %entry ], [ %add, %for.body ]
+  %arrayidx = getelementptr inbounds i16, ptr %a, i64 %indvars.iv
+  %0 = load i16, ptr %arrayidx, align 2
+  %conv = sext i16 %0 to i64
+  %arrayidx2 = getelementptr inbounds i16, ptr %b, i64 %indvars.iv
+  %1 = load i16, ptr %arrayidx2, align 2
+  %conv3 = sext i16 %1 to i64
+  %mul = mul nsw i64 %conv3, %conv
+  %add = add nsw i64 %mul, %acc.012
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond = icmp eq i64 %indvars.iv.next, %N
+  br i1 %exitcond, label %for.cond.cleanup.loopexit, label %for.body
+
+for.cond.cleanup.loopexit:                        ; preds = %for.body
+  %phitmp = lshr i64 %add, 16
+  %phitmp14 = trunc i64 %phitmp to i16
+  br label %for.cond.cleanup
+
+for.cond.cleanup:                                 ; preds = %for.cond.cleanup.loopexit
+  ret i16 %phitmp14
+}
+
+attributes #0 = { "target-features"="+sve" } ; #0 enables the SVE target feature for every function tagged with it



More information about the llvm-commits mailing list