[llvm] [LV] Bundle sub reductions into VPExpressionRecipe (PR #147255)

Sam Tebbs via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 1 08:27:12 PDT 2025


https://github.com/SamTebbs33 updated https://github.com/llvm/llvm-project/pull/147255

>From d64d52cfc5dc633ef322eb2338d275bb7dc00a03 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Mon, 30 Jun 2025 14:29:54 +0100
Subject: [PATCH 01/10] [LV] Bundle sub reductions into VPExpressionRecipe

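This PR bundles sub reductions, i.e. add reductions whose vector operand
is a negated multiply of extended values
(reduce.add(sub(0, mul(ext(A), ext(B))))), into the VPExpressionRecipe
class through a new ExtNegatedMulAccReduction expression type. It also
adds a Negated parameter to TTI::getMulAccReductionCost so that the cost
functions can take the negation into account; for now the BasicTTI,
AArch64 and ARM implementations return an invalid cost when the multiply
is negated.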
---
 .../llvm/Analysis/TargetTransformInfo.h       |   4 +-
 .../llvm/Analysis/TargetTransformInfoImpl.h   |   2 +-
 llvm/include/llvm/CodeGen/BasicTTIImpl.h      |   3 +
 llvm/lib/Analysis/TargetTransformInfo.cpp     |   5 +-
 .../AArch64/AArch64TargetTransformInfo.cpp    |   7 +-
 .../AArch64/AArch64TargetTransformInfo.h      |   2 +-
 .../lib/Target/ARM/ARMTargetTransformInfo.cpp |   7 +-
 llvm/lib/Target/ARM/ARMTargetTransformInfo.h  |   1 +
 .../Transforms/Vectorize/LoopVectorize.cpp    |   6 +-
 llvm/lib/Transforms/Vectorize/VPlan.h         |  11 ++
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  35 ++++-
 .../Transforms/Vectorize/VPlanTransforms.cpp  |  33 ++--
 .../Transforms/Vectorize/VectorCombine.cpp    |   4 +-
 .../vplan-printing-reductions.ll              | 143 ++++++++++++++++++
 14 files changed, 236 insertions(+), 27 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index c4ba8e9857dc4..9165f4aacfd2b 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1651,8 +1651,10 @@ class TargetTransformInfo {
   /// extensions. This is the cost of as:
   /// ResTy vecreduce.add(mul (A, B)).
   /// ResTy vecreduce.add(mul(ext(Ty A), ext(Ty B)).
+  /// The multiply can optionally be negated, which signifies that it is a sub
+  /// reduction.
   LLVM_ABI InstructionCost getMulAccReductionCost(
-      bool IsUnsigned, Type *ResTy, VectorType *Ty,
+      bool IsUnsigned, Type *ResTy, VectorType *Ty, bool Negated,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const;
 
   /// Calculate the cost of an extended reduction pattern, similar to
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 43813d2f3acb5..9d88d6a2c5bff 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -972,7 +972,7 @@ class TargetTransformInfoImplBase {
 
   virtual InstructionCost
   getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *Ty,
-                         TTI::TargetCostKind CostKind) const {
+                         bool Negated, TTI::TargetCostKind CostKind) const {
     return 1;
   }
 
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 0a10b51f97c63..4b6983fbd69a4 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -3261,7 +3261,10 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
 
   InstructionCost
   getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *Ty,
+                         bool Negated,
                          TTI::TargetCostKind CostKind) const override {
+    if (Negated)
+      return InstructionCost::getInvalid(CostKind);
     // Without any native support, this is equivalent to the cost of
     // vecreduce.add(mul(ext(Ty A), ext(Ty B))) or
     // vecreduce.add(mul(A, B)).
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 4ac8f03e6dbf5..eceecda244762 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1283,9 +1283,10 @@ InstructionCost TargetTransformInfo::getExtendedReductionCost(
 }
 
 InstructionCost TargetTransformInfo::getMulAccReductionCost(
-    bool IsUnsigned, Type *ResTy, VectorType *Ty,
+    bool IsUnsigned, Type *ResTy, VectorType *Ty, bool Negated,
     TTI::TargetCostKind CostKind) const {
-  return TTIImpl->getMulAccReductionCost(IsUnsigned, ResTy, Ty, CostKind);
+  return TTIImpl->getMulAccReductionCost(IsUnsigned, ResTy, Ty, Negated,
+                                         CostKind);
 }
 
 InstructionCost
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 490f6391c15a0..6ce01767d22be 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5487,8 +5487,10 @@ InstructionCost AArch64TTIImpl::getExtendedReductionCost(
 
 InstructionCost
 AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
-                                       VectorType *VecTy,
+                                       VectorType *VecTy, bool Negated,
                                        TTI::TargetCostKind CostKind) const {
+  if (Negated)
+    return InstructionCost::getInvalid(CostKind);
   EVT VecVT = TLI->getValueType(DL, VecTy);
   EVT ResVT = TLI->getValueType(DL, ResTy);
 
@@ -5503,7 +5505,8 @@ AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
       return LT.first + 2;
   }
 
-  return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, VecTy, CostKind);
+  return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, VecTy, Negated,
+                                       CostKind);
 }
 
 InstructionCost
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 42ae962b3b426..611593e248aef 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -460,7 +460,7 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
                            TTI::TargetCostKind CostKind) const override;
 
   InstructionCost getMulAccReductionCost(
-      bool IsUnsigned, Type *ResTy, VectorType *Ty,
+      bool IsUnsigned, Type *ResTy, VectorType *Ty, bool Negated,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const override;
 
   InstructionCost
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index 6b2854171c819..9821ffc4ffb29 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -1917,8 +1917,10 @@ InstructionCost ARMTTIImpl::getExtendedReductionCost(
 
 InstructionCost
 ARMTTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
-                                   VectorType *ValTy,
+                                   VectorType *ValTy, bool Negated,
                                    TTI::TargetCostKind CostKind) const {
+  if (Negated)
+    return InstructionCost::getInvalid(CostKind);
   EVT ValVT = TLI->getValueType(DL, ValTy);
   EVT ResVT = TLI->getValueType(DL, ResTy);
 
@@ -1939,7 +1941,8 @@ ARMTTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
       return ST->getMVEVectorCostFactor(CostKind) * LT.first;
   }
 
-  return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, ValTy, CostKind);
+  return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, ValTy, Negated,
+                                       CostKind);
 }
 
 InstructionCost
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index cdd8bcb9f7416..5a5d6755500df 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -300,6 +300,7 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
                            TTI::TargetCostKind CostKind) const override;
   InstructionCost
   getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *ValTy,
+                         bool Negated,
                          TTI::TargetCostKind CostKind) const override;
 
   InstructionCost
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 1b1797ab30a35..babb0b8e22040 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -5414,7 +5414,7 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
                              TTI::CastContextHint::None, CostKind, RedOp);
 
     InstructionCost RedCost = TTI.getMulAccReductionCost(
-        IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, CostKind);
+        IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, false, CostKind);
 
     if (RedCost.isValid() &&
         RedCost < ExtCost * 2 + MulCost + Ext2Cost + BaseCost)
@@ -5459,7 +5459,7 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
           TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind);
 
       InstructionCost RedCost = TTI.getMulAccReductionCost(
-          IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, CostKind);
+          IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, false, CostKind);
       InstructionCost ExtraExtCost = 0;
       if (Op0Ty != LargestOpTy || Op1Ty != LargestOpTy) {
         Instruction *ExtraExtOp = (Op0Ty != LargestOpTy) ? Op0 : Op1;
@@ -5478,7 +5478,7 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
           TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind);
 
       InstructionCost RedCost = TTI.getMulAccReductionCost(
-          true, RdxDesc.getRecurrenceType(), VectorTy, CostKind);
+          true, RdxDesc.getRecurrenceType(), VectorTy, false, CostKind);
 
       if (RedCost.isValid() && RedCost < MulCost + BaseCost)
         return I == RetI ? RedCost : 0;
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 8ed26f23c859b..2c79b9344168e 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2869,6 +2869,12 @@ class VPExpressionRecipe : public VPSingleDefRecipe {
     /// vector operands, performing a reduction.add on the result, and adding
     /// the scalar result to a chain.
     MulAccReduction,
+    /// Represent an inloop multiply-accumulate reduction, multiplying the
+    /// extended vector operands, negating the multiplication, performing a
+    /// reduction.add
+    /// on the result, and adding
+    /// the scalar result to a chain.
+    ExtNegatedMulAccReduction,
   };
 
   /// Type of the expression.
@@ -2892,6 +2898,11 @@ class VPExpressionRecipe : public VPSingleDefRecipe {
                      VPWidenRecipe *Mul, VPReductionRecipe *Red)
       : VPExpressionRecipe(ExpressionTypes::ExtMulAccReduction,
                            {Ext0, Ext1, Mul, Red}) {}
+  VPExpressionRecipe(VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1,
+                     VPWidenRecipe *Mul, VPWidenRecipe *Sub,
+                     VPReductionRecipe *Red)
+      : VPExpressionRecipe(ExpressionTypes::ExtNegatedMulAccReduction,
+                           {Ext0, Ext1, Mul, Sub, Red}) {}
 
   ~VPExpressionRecipe() override {
     for (auto *R : reverse(ExpressionRecipes))
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index bd9a93ed57b8a..88bff35d68680 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2814,13 +2814,17 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF,
         RedTy, SrcVecTy, std::nullopt, Ctx.CostKind);
   }
   case ExpressionTypes::MulAccReduction:
-    return Ctx.TTI.getMulAccReductionCost(false, RedTy, SrcVecTy, Ctx.CostKind);
+    return Ctx.TTI.getMulAccReductionCost(false, RedTy, SrcVecTy, false,
+                                          Ctx.CostKind);
 
-  case ExpressionTypes::ExtMulAccReduction:
+  case ExpressionTypes::ExtNegatedMulAccReduction:
+  case ExpressionTypes::ExtMulAccReduction: {
+    bool Negated = ExpressionType == ExpressionTypes::ExtNegatedMulAccReduction;
     return Ctx.TTI.getMulAccReductionCost(
         cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() ==
             Instruction::ZExt,
-        RedTy, SrcVecTy, Ctx.CostKind);
+        RedTy, SrcVecTy, Negated, Ctx.CostKind);
+  }
   }
   llvm_unreachable("Unknown VPExpressionRecipe::ExpressionTypes enum");
 }
@@ -2867,6 +2871,31 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent,
     O << ")";
     break;
   }
+  case ExpressionTypes::ExtNegatedMulAccReduction: {
+    getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker);
+    O << " + ";
+    O << "reduce."
+      << Instruction::getOpcodeName(
+             RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()))
+      << " (sub (0, mul";
+    auto *Mul = cast<VPWidenRecipe>(ExpressionRecipes[2]);
+    Mul->printFlags(O);
+    O << "(";
+    getOperand(0)->printAsOperand(O, SlotTracker);
+    auto *Ext0 = cast<VPWidenCastRecipe>(ExpressionRecipes[0]);
+    O << " " << Instruction::getOpcodeName(Ext0->getOpcode()) << " to "
+      << *Ext0->getResultType() << "), (";
+    getOperand(1)->printAsOperand(O, SlotTracker);
+    auto *Ext1 = cast<VPWidenCastRecipe>(ExpressionRecipes[1]);
+    O << " " << Instruction::getOpcodeName(Ext1->getOpcode()) << " to "
+      << *Ext1->getResultType() << ")";
+    if (Red->isConditional()) {
+      O << ", ";
+      Red->getCondOp()->printAsOperand(O, SlotTracker);
+    }
+    O << "))";
+    break;
+  }
   case ExpressionTypes::MulAccReduction:
   case ExpressionTypes::ExtMulAccReduction: {
     getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker);
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 6c5f9b7302292..132e22d0429e3 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -3158,16 +3158,17 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
 
   // Clamp the range if using multiply-accumulate-reduction is profitable.
   auto IsMulAccValidAndClampRange =
-      [&](bool isZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
-          VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt) -> bool {
+      [&](bool IsZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
+          VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt,
+          bool Negated = false) -> bool {
     return LoopVectorizationPlanner::getDecisionAndClampRange(
         [&](ElementCount VF) {
           TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
           Type *SrcTy =
               Ext0 ? Ctx.Types.inferScalarType(Ext0->getOperand(0)) : RedTy;
           auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF));
-          InstructionCost MulAccCost =
-              Ctx.TTI.getMulAccReductionCost(isZExt, RedTy, SrcVecTy, CostKind);
+          InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
+              IsZExt, RedTy, SrcVecTy, Negated, CostKind);
           InstructionCost MulCost = Mul->computeCost(VF, Ctx);
           InstructionCost RedCost = Red->computeCost(VF, Ctx);
           InstructionCost ExtCost = 0;
@@ -3185,14 +3186,22 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
   };
 
   VPValue *VecOp = Red->getVecOp();
+  VPValue *Mul = nullptr;
+  VPValue *Sub = nullptr;
   VPValue *A, *B;
+  // Sub reductions will have a sub between the add reduction and vec op.
+  if (match(VecOp,
+            m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(Mul))))
+    Sub = VecOp;
+  else
+    Mul = VecOp;
   // Try to match reduce.add(mul(...)).
-  if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
+  if (match(Mul, m_Mul(m_VPValue(A), m_VPValue(B)))) {
     auto *RecipeA =
         dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
     auto *RecipeB =
         dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
-    auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe());
+    auto *MulR = cast<VPWidenRecipe>(Mul->getDefiningRecipe());
 
     // Match reduce.add(mul(ext, ext)).
     if (RecipeA && RecipeB &&
@@ -3201,12 +3210,16 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
         match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
         IsMulAccValidAndClampRange(RecipeA->getOpcode() ==
                                        Instruction::CastOps::ZExt,
-                                   Mul, RecipeA, RecipeB, nullptr)) {
-      return new VPExpressionRecipe(RecipeA, RecipeB, Mul, Red);
+                                   MulR, RecipeA, RecipeB, nullptr, Sub)) {
+      if (Sub)
+        return new VPExpressionRecipe(
+            RecipeA, RecipeB, MulR,
+            cast<VPWidenRecipe>(Sub->getDefiningRecipe()), Red);
+      return new VPExpressionRecipe(RecipeA, RecipeB, MulR, Red);
     }
     // Match reduce.add(mul).
-    if (IsMulAccValidAndClampRange(true, Mul, nullptr, nullptr, nullptr))
-      return new VPExpressionRecipe(Mul, Red);
+    if (IsMulAccValidAndClampRange(true, MulR, nullptr, nullptr, nullptr, Sub))
+      return new VPExpressionRecipe(MulR, Red);
   }
   // Match reduce.add(ext(mul(ext(A), ext(B)))).
   // All extend recipes must have same opcode or A == B
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index c88ed95de2946..68f60e867490e 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1468,8 +1468,8 @@ static void analyzeCostOfVecReduction(const IntrinsicInst &II,
                              TTI::CastContextHint::None, CostKind, RedOp);
 
     CostBeforeReduction = ExtCost * 2 + MulCost + Ext2Cost;
-    CostAfterReduction =
-        TTI.getMulAccReductionCost(IsUnsigned, II.getType(), ExtType, CostKind);
+    CostAfterReduction = TTI.getMulAccReductionCost(IsUnsigned, II.getType(),
+                                                    ExtType, false, CostKind);
     return;
   }
   CostAfterReduction = TTI.getArithmeticReductionCost(ReductionOpc, VecRedTy,
diff --git a/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll
index 4af3fa9202c77..8059ac12ecd2e 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll
@@ -416,3 +416,146 @@ exit:
   %r.0.lcssa = phi i64 [ %rdx.next, %loop ]
   ret i64 %r.0.lcssa
 }
+
+define i32 @print_mulacc_sub(ptr %a, ptr %b) {
+; CHECK:      VPlan 'Initial VPlan for VF={4},UF>=1' {
+; CHECK-NEXT: Live-in vp<%0> = VF
+; CHECK-NEXT: Live-in vp<%1> = VF * UF
+; CHECK-NEXT: Live-in vp<%2> = vector-trip-count
+; CHECK-NEXT: Live-in ir<1024> = original trip-count
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<entry>:
+; CHECK-NEXT: Successor(s): scalar.ph, vector.ph
+; CHECK-EMPTY:
+; CHECK-NEXT: vector.ph:
+; CHECK-NEXT:   EMIT vp<%3> = reduction-start-vector ir<0>, ir<0>, ir<1>
+; CHECK-NEXT: Successor(s): vector loop
+; CHECK-EMPTY:
+; CHECK-NEXT: <x1> vector loop: {
+; CHECK-NEXT:   vector.body:
+; CHECK-NEXT:     EMIT vp<%4> = CANONICAL-INDUCTION ir<0>, vp<%index.next>
+; CHECK-NEXT:     WIDEN-REDUCTION-PHI ir<%accum> = phi vp<%3>, vp<%8>
+; CHECK-NEXT:     vp<%5> = SCALAR-STEPS vp<%4>, ir<1>, vp<%0>
+; CHECK-NEXT:     CLONE ir<%gep.a> = getelementptr ir<%a>, vp<%5>
+; CHECK-NEXT:     vp<%6> = vector-pointer ir<%gep.a>
+; CHECK-NEXT:     WIDEN ir<%load.a> = load vp<%6>
+; CHECK-NEXT:     CLONE ir<%gep.b> = getelementptr ir<%b>, vp<%5>
+; CHECK-NEXT:     vp<%7> = vector-pointer ir<%gep.b>
+; CHECK-NEXT:     WIDEN ir<%load.b> = load vp<%7>
+; CHECK-NEXT:     EXPRESSION vp<%8> = ir<%accum> + reduce.add (sub (0, mul (ir<%load.b> zext to i32), (ir<%load.a> zext to i32)))
+; CHECK-NEXT:     EMIT vp<%index.next> = add nuw vp<%4>, vp<%1>
+; CHECK-NEXT:     EMIT branch-on-count vp<%index.next>, vp<%2>
+; CHECK-NEXT:   No successors
+; CHECK-NEXT: }
+; CHECK-NEXT: Successor(s): middle.block
+; CHECK-EMPTY:
+; CHECK-NEXT: middle.block:
+; CHECK-NEXT:   EMIT vp<%10> = compute-reduction-result ir<%accum>, vp<%8>
+; CHECK-NEXT:   EMIT vp<%cmp.n> = icmp eq ir<1024>, vp<%2>
+; CHECK-NEXT:   EMIT branch-on-cond vp<%cmp.n>
+; CHECK-NEXT: Successor(s): ir-bb<for.exit>, scalar.ph
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.exit>:
+; CHECK-NEXT:   IR   %add.lcssa = phi i32 [ %add, %for.body ] (extra operand: vp<%10> from middle.block)
+; CHECK-NEXT: No successors
+; CHECK-EMPTY:
+; CHECK-NEXT: scalar.ph:
+; CHECK-NEXT:   EMIT-SCALAR vp<%bc.resume.val> = phi [ vp<%2>, middle.block ], [ ir<0>, ir-bb<entry> ]
+; CHECK-NEXT:   EMIT-SCALAR vp<%bc.merge.rdx> = phi [ vp<%10>, middle.block ], [ ir<0>, ir-bb<entry> ]
+; CHECK-NEXT: Successor(s): ir-bb<for.body>
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.body>:
+; CHECK-NEXT:   IR   %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ] (extra operand: vp<%bc.resume.val> from scalar.ph)
+; CHECK-NEXT:   IR   %accum = phi i32 [ 0, %entry ], [ %add, %for.body ] (extra operand: vp<%bc.merge.rdx> from scalar.ph)
+; CHECK-NEXT:   IR   %gep.a = getelementptr i8, ptr %a, i64 %iv
+; CHECK-NEXT:   IR   %load.a = load i8, ptr %gep.a, align 1
+; CHECK-NEXT:   IR   %ext.a = zext i8 %load.a to i32
+; CHECK-NEXT:   IR   %gep.b = getelementptr i8, ptr %b, i64 %iv
+; CHECK-NEXT:   IR   %load.b = load i8, ptr %gep.b, align 1
+; CHECK-NEXT:   IR   %ext.b = zext i8 %load.b to i32
+; CHECK-NEXT:   IR   %mul = mul i32 %ext.b, %ext.a
+; CHECK-NEXT:   IR   %add = sub i32 %accum, %mul
+; CHECK-NEXT:   IR   %iv.next = add i64 %iv, 1
+; CHECK-NEXT:   IR   %exitcond.not = icmp eq i64 %iv.next, 1024
+; CHECK-NEXT: No successors
+; CHECK-NEXT: }
+; CHECK:      VPlan 'Final VPlan for VF={4},UF={1}' {
+; CHECK-NEXT: Live-in ir<4> = VF * UF
+; CHECK-NEXT: Live-in ir<1024> = vector-trip-count
+; CHECK-NEXT: Live-in ir<1024> = original trip-count
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<entry>:
+; CHECK-NEXT: Successor(s): ir-bb<scalar.ph>, ir-bb<vector.ph>
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<vector.ph>:
+; CHECK-NEXT: Successor(s): vector.body
+; CHECK-EMPTY:
+; CHECK-NEXT: vector.body:
+; CHECK-NEXT:   EMIT-SCALAR vp<%index> = phi [ ir<0>, ir-bb<vector.ph> ], [ vp<%index.next>, vector.body ]
+; CHECK-NEXT:   WIDEN-REDUCTION-PHI ir<%accum> = phi ir<0>, ir<%add>.1
+; CHECK-NEXT:   CLONE ir<%gep.a> = getelementptr ir<%a>, vp<%index>
+; CHECK-NEXT:   vp<%1> = vector-pointer ir<%gep.a>
+; CHECK-NEXT:   WIDEN ir<%load.a> = load vp<%1>
+; CHECK-NEXT:   CLONE ir<%gep.b> = getelementptr ir<%b>, vp<%index>
+; CHECK-NEXT:   vp<%2> = vector-pointer ir<%gep.b>
+; CHECK-NEXT:   WIDEN ir<%load.b> = load vp<%2>
+; CHECK-NEXT:   WIDEN-CAST ir<%ext.b> = zext ir<%load.b> to i32
+; CHECK-NEXT:   WIDEN-CAST ir<%ext.a> = zext ir<%load.a> to i32
+; CHECK-NEXT:   WIDEN ir<%mul> = mul ir<%ext.b>, ir<%ext.a>
+; CHECK-NEXT:   WIDEN ir<%add> = sub ir<0>, ir<%mul>
+; CHECK-NEXT:   REDUCE ir<%add>.1 = ir<%accum> + reduce.add (ir<%add>)
+; CHECK-NEXT:   EMIT vp<%index.next> = add nuw vp<%index>, ir<4>
+; CHECK-NEXT:   EMIT branch-on-count vp<%index.next>, ir<1024>
+; CHECK-NEXT: Successor(s): middle.block, vector.body
+; CHECK-EMPTY:
+; CHECK-NEXT: middle.block:
+; CHECK-NEXT:   EMIT vp<%4> = compute-reduction-result ir<%accum>, ir<%add>.1
+; CHECK-NEXT:   EMIT vp<%cmp.n> = icmp eq ir<1024>, ir<1024>
+; CHECK-NEXT:   EMIT branch-on-cond vp<%cmp.n>
+; CHECK-NEXT: Successor(s): ir-bb<for.exit>, ir-bb<scalar.ph>
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.exit>:
+; CHECK-NEXT:   IR   %add.lcssa = phi i32 [ %add, %for.body ] (extra operand: vp<%4> from middle.block)
+; CHECK-NEXT: No successors
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<scalar.ph>:
+; CHECK-NEXT:   EMIT-SCALAR vp<%bc.resume.val> = phi [ ir<1024>, middle.block ], [ ir<0>, ir-bb<entry> ]
+; CHECK-NEXT:   EMIT-SCALAR vp<%bc.merge.rdx> = phi [ vp<%4>, middle.block ], [ ir<0>, ir-bb<entry> ]
+; CHECK-NEXT: Successor(s): ir-bb<for.body>
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.body>:
+; CHECK-NEXT:   IR   %iv = phi i64 [ 0, %scalar.ph ], [ %iv.next, %for.body ] (extra operand: vp<%bc.resume.val> from ir-bb<scalar.ph>)
+; CHECK-NEXT:   IR   %accum = phi i32 [ 0, %scalar.ph ], [ %add, %for.body ] (extra operand: vp<%bc.merge.rdx> from ir-bb<scalar.ph>)
+; CHECK-NEXT:   IR   %gep.a = getelementptr i8, ptr %a, i64 %iv
+; CHECK-NEXT:   IR   %load.a = load i8, ptr %gep.a, align 1
+; CHECK-NEXT:   IR   %ext.a = zext i8 %load.a to i32
+; CHECK-NEXT:   IR   %gep.b = getelementptr i8, ptr %b, i64 %iv
+; CHECK-NEXT:   IR   %load.b = load i8, ptr %gep.b, align 1
+; CHECK-NEXT:   IR   %ext.b = zext i8 %load.b to i32
+; CHECK-NEXT:   IR   %mul = mul i32 %ext.b, %ext.a
+; CHECK-NEXT:   IR   %add = sub i32 %accum, %mul
+; CHECK-NEXT:   IR   %iv.next = add i64 %iv, 1
+; CHECK-NEXT:   IR   %exitcond.not = icmp eq i64 %iv.next, 1024
+; CHECK-NEXT: No successors
+; CHECK-NEXT: }
+entry:
+  br label %for.body
+
+for.body:                                         ; preds = %for.body, %entry
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
+  %accum = phi i32 [ 0, %entry ], [ %add, %for.body ]
+  %gep.a = getelementptr i8, ptr %a, i64 %iv
+  %load.a = load i8, ptr %gep.a, align 1
+  %ext.a = zext i8 %load.a to i32
+  %gep.b = getelementptr i8, ptr %b, i64 %iv
+  %load.b = load i8, ptr %gep.b, align 1
+  %ext.b = zext i8 %load.b to i32
+  %mul = mul i32 %ext.b, %ext.a
+  %add = sub i32 %accum, %mul
+  %iv.next = add i64 %iv, 1
+  %exitcond.not = icmp eq i64 %iv.next, 1024
+  br i1 %exitcond.not, label %for.exit, label %for.body
+
+for.exit:                        ; preds = %for.body
+  ret i32 %add
+}

>From 3171e7b8b86b819f97c6962a0978e03a13cf9e30 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Tue, 8 Jul 2025 15:25:46 +0100
Subject: [PATCH 02/10] Move IsNegated parameter and cost sub

---
 llvm/include/llvm/Analysis/TargetTransformInfo.h   |  2 +-
 .../llvm/Analysis/TargetTransformInfoImpl.h        |  4 ++--
 llvm/include/llvm/CodeGen/BasicTTIImpl.h           | 14 ++++++++------
 llvm/lib/Analysis/TargetTransformInfo.cpp          |  4 ++--
 .../Target/AArch64/AArch64TargetTransformInfo.cpp  |  8 ++++----
 .../Target/AArch64/AArch64TargetTransformInfo.h    |  2 +-
 llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp     |  8 ++++----
 llvm/lib/Target/ARM/ARMTargetTransformInfo.h       |  4 ++--
 llvm/lib/Transforms/Vectorize/LoopVectorize.cpp    |  9 ++++++---
 llvm/lib/Transforms/Vectorize/VPlan.h              |  4 +---
 llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp     |  4 ++--
 llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp  |  2 +-
 llvm/lib/Transforms/Vectorize/VectorCombine.cpp    |  4 ++--
 13 files changed, 36 insertions(+), 33 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 9165f4aacfd2b..eacfd833f3fa0 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1654,7 +1654,7 @@ class TargetTransformInfo {
   /// The multiply can optionally be negated, which signifies that it is a sub
   /// reduction.
   LLVM_ABI InstructionCost getMulAccReductionCost(
-      bool IsUnsigned, Type *ResTy, VectorType *Ty, bool Negated,
+      bool IsUnsigned, bool IsNegated, Type *ResTy, VectorType *Ty,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const;
 
   /// Calculate the cost of an extended reduction pattern, similar to
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 9d88d6a2c5bff..4f9d5e89a44c8 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -971,8 +971,8 @@ class TargetTransformInfoImplBase {
   }
 
   virtual InstructionCost
-  getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *Ty,
-                         bool Negated, TTI::TargetCostKind CostKind) const {
+  getMulAccReductionCost(bool IsUnsigned, bool IsNegated, Type *ResTy,
+                         VectorType *Ty, TTI::TargetCostKind CostKind) const {
     return 1;
   }
 
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 4b6983fbd69a4..4f0644330d7a9 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -3260,14 +3260,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
   }
 
   InstructionCost
-  getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *Ty,
-                         bool Negated,
+  getMulAccReductionCost(bool IsUnsigned, bool IsNegated, Type *ResTy,
+                         VectorType *Ty,
                          TTI::TargetCostKind CostKind) const override {
-    if (Negated)
-      return InstructionCost::getInvalid(CostKind);
     // Without any native support, this is equivalent to the cost of
     // vecreduce.add(mul(ext(Ty A), ext(Ty B))) or
-    // vecreduce.add(mul(A, B)).
+    // vecreduce.add(mul(A, B)) with an optional negation of the mul.
     VectorType *ExtTy = VectorType::get(ResTy, Ty);
     InstructionCost RedCost = thisT()->getArithmeticReductionCost(
         Instruction::Add, ExtTy, std::nullopt, CostKind);
@@ -3277,8 +3275,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
 
     InstructionCost MulCost =
         thisT()->getArithmeticInstrCost(Instruction::Mul, ExtTy, CostKind);
+    InstructionCost SubCost =
+        IsNegated
+            ? thisT()->getArithmeticInstrCost(Instruction::Sub, ExtTy, CostKind)
+            : 0;
 
-    return RedCost + MulCost + 2 * ExtCost;
+    return RedCost + SubCost + MulCost + 2 * ExtCost;
   }
 
   InstructionCost getVectorSplitCost() const { return 1; }
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index eceecda244762..9c3ef9c51cee9 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1283,9 +1283,9 @@ InstructionCost TargetTransformInfo::getExtendedReductionCost(
 }
 
 InstructionCost TargetTransformInfo::getMulAccReductionCost(
-    bool IsUnsigned, Type *ResTy, VectorType *Ty, bool Negated,
+    bool IsUnsigned, bool IsNegated, Type *ResTy, VectorType *Ty,
     TTI::TargetCostKind CostKind) const {
-  return TTIImpl->getMulAccReductionCost(IsUnsigned, ResTy, Ty, Negated,
+  return TTIImpl->getMulAccReductionCost(IsUnsigned, IsNegated, ResTy, Ty,
                                          CostKind);
 }
 
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 6ce01767d22be..cc2a01198dbec 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5486,10 +5486,10 @@ InstructionCost AArch64TTIImpl::getExtendedReductionCost(
 }
 
 InstructionCost
-AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
-                                       VectorType *VecTy, bool Negated,
+AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, bool IsNegated,
+                                       Type *ResTy, VectorType *VecTy,
                                        TTI::TargetCostKind CostKind) const {
-  if (Negated)
+  if (IsNegated)
     return InstructionCost::getInvalid(CostKind);
   EVT VecVT = TLI->getValueType(DL, VecTy);
   EVT ResVT = TLI->getValueType(DL, ResTy);
@@ -5505,7 +5505,7 @@ AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
       return LT.first + 2;
   }
 
-  return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, VecTy, Negated,
+  return BaseT::getMulAccReductionCost(IsUnsigned, IsNegated, ResTy, VecTy,
                                        CostKind);
 }
 
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 611593e248aef..38bd48ac600cb 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -460,7 +460,7 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
                            TTI::TargetCostKind CostKind) const override;
 
   InstructionCost getMulAccReductionCost(
-      bool IsUnsigned, Type *ResTy, VectorType *Ty, bool Negated,
+      bool IsUnsigned, bool IsNegated, Type *ResTy, VectorType *Ty,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const override;
 
   InstructionCost
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index 9821ffc4ffb29..a10210ebbdd30 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -1916,10 +1916,10 @@ InstructionCost ARMTTIImpl::getExtendedReductionCost(
 }
 
 InstructionCost
-ARMTTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
-                                   VectorType *ValTy, bool Negated,
+ARMTTIImpl::getMulAccReductionCost(bool IsUnsigned, bool IsNegated, Type *ResTy,
+                                   VectorType *ValTy,
                                    TTI::TargetCostKind CostKind) const {
-  if (Negated)
+  if (IsNegated)
     return InstructionCost::getInvalid(CostKind);
   EVT ValVT = TLI->getValueType(DL, ValTy);
   EVT ResVT = TLI->getValueType(DL, ResTy);
@@ -1941,7 +1941,7 @@ ARMTTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
       return ST->getMVEVectorCostFactor(CostKind) * LT.first;
   }
 
-  return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, ValTy, Negated,
+  return BaseT::getMulAccReductionCost(IsUnsigned, IsNegated, ResTy, ValTy,
                                        CostKind);
 }
 
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index 5a5d6755500df..fab5c915a6971 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -299,8 +299,8 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
                            VectorType *ValTy, std::optional<FastMathFlags> FMF,
                            TTI::TargetCostKind CostKind) const override;
   InstructionCost
-  getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *ValTy,
-                         bool Negated,
+  getMulAccReductionCost(bool IsUnsigned, bool IsNegated, Type *ResTy,
+                         VectorType *ValTy,
                          TTI::TargetCostKind CostKind) const override;
 
   InstructionCost
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index babb0b8e22040..db9a1aa8352f1 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -5414,7 +5414,8 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
                              TTI::CastContextHint::None, CostKind, RedOp);
 
     InstructionCost RedCost = TTI.getMulAccReductionCost(
-        IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, false, CostKind);
+        IsUnsigned, /*IsNegated=*/false, RdxDesc.getRecurrenceType(), ExtType,
+        CostKind);
 
     if (RedCost.isValid() &&
         RedCost < ExtCost * 2 + MulCost + Ext2Cost + BaseCost)
@@ -5459,7 +5460,8 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
           TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind);
 
       InstructionCost RedCost = TTI.getMulAccReductionCost(
-          IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, false, CostKind);
+          IsUnsigned, /*IsNegated=*/false, RdxDesc.getRecurrenceType(), ExtType,
+          CostKind);
       InstructionCost ExtraExtCost = 0;
       if (Op0Ty != LargestOpTy || Op1Ty != LargestOpTy) {
         Instruction *ExtraExtOp = (Op0Ty != LargestOpTy) ? Op0 : Op1;
@@ -5478,7 +5480,8 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
           TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind);
 
       InstructionCost RedCost = TTI.getMulAccReductionCost(
-          true, RdxDesc.getRecurrenceType(), VectorTy, false, CostKind);
+          true, /*IsNegated=*/false, RdxDesc.getRecurrenceType(), VectorTy,
+          CostKind);
 
       if (RedCost.isValid() && RedCost < MulCost + BaseCost)
         return I == RetI ? RedCost : 0;
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 2c79b9344168e..d49107b300ceb 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2871,9 +2871,7 @@ class VPExpressionRecipe : public VPSingleDefRecipe {
     MulAccReduction,
     /// Represent an inloop multiply-accumulate reduction, multiplying the
     /// extended vector operands, negating the multiplication, performing a
-    /// reduction.add
-    /// on the result, and adding
-    /// the scalar result to a chain.
+    /// reduction.add on the result, and adding the scalar result to a chain.
     ExtNegatedMulAccReduction,
   };
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 88bff35d68680..b374ef1341806 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2814,7 +2814,7 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF,
         RedTy, SrcVecTy, std::nullopt, Ctx.CostKind);
   }
   case ExpressionTypes::MulAccReduction:
-    return Ctx.TTI.getMulAccReductionCost(false, RedTy, SrcVecTy, false,
+    return Ctx.TTI.getMulAccReductionCost(false, false, RedTy, SrcVecTy,
                                           Ctx.CostKind);
 
   case ExpressionTypes::ExtNegatedMulAccReduction:
@@ -2823,7 +2823,7 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF,
     return Ctx.TTI.getMulAccReductionCost(
         cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() ==
             Instruction::ZExt,
-        RedTy, SrcVecTy, Negated, Ctx.CostKind);
+        Negated, RedTy, SrcVecTy, Ctx.CostKind);
   }
   }
   llvm_unreachable("Unknown VPExpressionRecipe::ExpressionTypes enum");
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 132e22d0429e3..7d3f1357a1d02 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -3168,7 +3168,7 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
               Ext0 ? Ctx.Types.inferScalarType(Ext0->getOperand(0)) : RedTy;
           auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF));
           InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
-              IsZExt, RedTy, SrcVecTy, Negated, CostKind);
+              IsZExt, Negated, RedTy, SrcVecTy, CostKind);
           InstructionCost MulCost = Mul->computeCost(VF, Ctx);
           InstructionCost RedCost = Red->computeCost(VF, Ctx);
           InstructionCost ExtCost = 0;
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 68f60e867490e..c5d327ec99016 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1468,8 +1468,8 @@ static void analyzeCostOfVecReduction(const IntrinsicInst &II,
                              TTI::CastContextHint::None, CostKind, RedOp);
 
     CostBeforeReduction = ExtCost * 2 + MulCost + Ext2Cost;
-    CostAfterReduction = TTI.getMulAccReductionCost(IsUnsigned, II.getType(),
-                                                    ExtType, false, CostKind);
+    CostAfterReduction = TTI.getMulAccReductionCost(
+        IsUnsigned, /*IsNegated=*/false, II.getType(), ExtType, CostKind);
     return;
   }
   CostAfterReduction = TTI.getArithmeticReductionCost(ReductionOpc, VecRedTy,

>From cacb89e9fb1a89078cb444cbcb961609d338fc45 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Wed, 13 Aug 2025 16:46:41 +0100
Subject: [PATCH 03/10] Improve getMulAccReductionCost comment

---
 llvm/include/llvm/CodeGen/BasicTTIImpl.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 4f0644330d7a9..1c70bd39732c3 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -3265,7 +3265,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
                          TTI::TargetCostKind CostKind) const override {
     // Without any native support, this is equivalent to the cost of
     // vecreduce.add(mul(ext(Ty A), ext(Ty B))) or
-    // vecreduce.add(mul(A, B)) with an optional negation of the mul.
+    // vecreduce.add(mul(A, B)). IsNegated determines if the mul is negated.
     VectorType *ExtTy = VectorType::get(ResTy, Ty);
     InstructionCost RedCost = thisT()->getArithmeticReductionCost(
         Instruction::Add, ExtTy, std::nullopt, CostKind);

>From 53fca5ccc92db4cae06f580d26d6c007fc52acd0 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Wed, 20 Aug 2025 15:41:25 +0100
Subject: [PATCH 04/10] Accept reductions with sub opcode

---
 llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 7d3f1357a1d02..873496a7a8ad6 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -3151,7 +3151,7 @@ static VPExpressionRecipe *
 tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
                                           VPCostContext &Ctx, VFRange &Range) {
   unsigned Opcode = RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind());
-  if (Opcode != Instruction::Add)
+  if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
     return nullptr;
 
   Type *RedTy = Ctx.Types.inferScalarType(Red);

>From 55a9c3e47495e6f42f2f20aa67824e1e1a7e89e1 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Wed, 20 Aug 2025 20:12:48 +0100
Subject: [PATCH 05/10] Rebase and remove negated expression type

---
 .../llvm/Analysis/TargetTransformInfo.h       |  12 +-
 .../llvm/Analysis/TargetTransformInfoImpl.h   |   2 +-
 llvm/include/llvm/CodeGen/BasicTTIImpl.h      |  10 +-
 llvm/lib/Analysis/TargetTransformInfo.cpp     |   4 +-
 .../AArch64/AArch64TargetTransformInfo.cpp    |   6 +-
 .../AArch64/AArch64TargetTransformInfo.h      |   2 +-
 .../lib/Target/ARM/ARMTargetTransformInfo.cpp |   8 +-
 llvm/lib/Target/ARM/ARMTargetTransformInfo.h  |   2 +-
 .../Transforms/Vectorize/LoopVectorize.cpp    |   6 +-
 llvm/lib/Transforms/Vectorize/VPlan.h         |   9 --
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  38 +-----
 .../Transforms/Vectorize/VPlanTransforms.cpp  |  24 +---
 .../Transforms/Vectorize/VectorCombine.cpp    |   2 +-
 .../vplan-printing-reductions.ll              | 127 +++++++++++++++---
 14 files changed, 145 insertions(+), 107 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index eacfd833f3fa0..b84f9d7775f4e 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1647,14 +1647,12 @@ class TargetTransformInfo {
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const;
 
   /// Calculate the cost of an extended reduction pattern, similar to
-  /// getArithmeticReductionCost of an Add reduction with multiply and optional
-  /// extensions. This is the cost of as:
-  /// ResTy vecreduce.add(mul (A, B)).
-  /// ResTy vecreduce.add(mul(ext(Ty A), ext(Ty B)).
-  /// The multiply can optionally be negated, which signifies that it is a sub
-  /// reduction.
+  /// getArithmeticReductionCost of an Add/Sub reduction with multiply and
+  /// optional extensions. This is the cost of:
+  /// ResTy vecreduce.add/sub(mul (A, B)).
+  /// ResTy vecreduce.add/sub(mul(ext(Ty A), ext(Ty B))).
   LLVM_ABI InstructionCost getMulAccReductionCost(
-      bool IsUnsigned, bool IsNegated, Type *ResTy, VectorType *Ty,
+      bool IsUnsigned, unsigned RedOpcode, Type *ResTy, VectorType *Ty,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const;
 
   /// Calculate the cost of an extended reduction pattern, similar to
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 4f9d5e89a44c8..9c2ebb1891cac 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -971,7 +971,7 @@ class TargetTransformInfoImplBase {
   }
 
   virtual InstructionCost
-  getMulAccReductionCost(bool IsUnsigned, bool IsNegated, Type *ResTy,
+  getMulAccReductionCost(bool IsUnsigned, unsigned RedOpcode, Type *ResTy,
                          VectorType *Ty, TTI::TargetCostKind CostKind) const {
     return 1;
   }
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 1c70bd39732c3..823826edabc5a 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -3260,7 +3260,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
   }
 
   InstructionCost
-  getMulAccReductionCost(bool IsUnsigned, bool IsNegated, Type *ResTy,
+  getMulAccReductionCost(bool IsUnsigned, unsigned RedOpcode, Type *ResTy,
                          VectorType *Ty,
                          TTI::TargetCostKind CostKind) const override {
     // Without any native support, this is equivalent to the cost of
@@ -3268,19 +3268,15 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     // vecreduce.add(mul(A, B)). IsNegated determines if the mul is negated.
     VectorType *ExtTy = VectorType::get(ResTy, Ty);
     InstructionCost RedCost = thisT()->getArithmeticReductionCost(
-        Instruction::Add, ExtTy, std::nullopt, CostKind);
+        RedOpcode, ExtTy, std::nullopt, CostKind);
     InstructionCost ExtCost = thisT()->getCastInstrCost(
         IsUnsigned ? Instruction::ZExt : Instruction::SExt, ExtTy, Ty,
         TTI::CastContextHint::None, CostKind);
 
     InstructionCost MulCost =
         thisT()->getArithmeticInstrCost(Instruction::Mul, ExtTy, CostKind);
-    InstructionCost SubCost =
-        IsNegated
-            ? thisT()->getArithmeticInstrCost(Instruction::Sub, ExtTy, CostKind)
-            : 0;
 
-    return RedCost + SubCost + MulCost + 2 * ExtCost;
+    return RedCost + MulCost + 2 * ExtCost;
   }
 
   InstructionCost getVectorSplitCost() const { return 1; }
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 9c3ef9c51cee9..b4fa0d5964cb6 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1283,9 +1283,9 @@ InstructionCost TargetTransformInfo::getExtendedReductionCost(
 }
 
 InstructionCost TargetTransformInfo::getMulAccReductionCost(
-    bool IsUnsigned, bool IsNegated, Type *ResTy, VectorType *Ty,
+    bool IsUnsigned, unsigned RedOpcode, Type *ResTy, VectorType *Ty,
     TTI::TargetCostKind CostKind) const {
-  return TTIImpl->getMulAccReductionCost(IsUnsigned, IsNegated, ResTy, Ty,
+  return TTIImpl->getMulAccReductionCost(IsUnsigned, RedOpcode, ResTy, Ty,
                                          CostKind);
 }
 
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index cc2a01198dbec..2b09aa0f9c0cf 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5486,10 +5486,10 @@ InstructionCost AArch64TTIImpl::getExtendedReductionCost(
 }
 
 InstructionCost
-AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, bool IsNegated,
+AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, unsigned RedOpcode,
                                        Type *ResTy, VectorType *VecTy,
                                        TTI::TargetCostKind CostKind) const {
-  if (IsNegated)
+  if (RedOpcode != Instruction::Add)
     return InstructionCost::getInvalid(CostKind);
   EVT VecVT = TLI->getValueType(DL, VecTy);
   EVT ResVT = TLI->getValueType(DL, ResTy);
@@ -5505,7 +5505,7 @@ AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, bool IsNegated,
       return LT.first + 2;
   }
 
-  return BaseT::getMulAccReductionCost(IsUnsigned, IsNegated, ResTy, VecTy,
+  return BaseT::getMulAccReductionCost(IsUnsigned, RedOpcode, ResTy, VecTy,
                                        CostKind);
 }
 
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 38bd48ac600cb..b994ca74aa222 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -460,7 +460,7 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
                            TTI::TargetCostKind CostKind) const override;
 
   InstructionCost getMulAccReductionCost(
-      bool IsUnsigned, bool IsNegated, Type *ResTy, VectorType *Ty,
+      bool IsUnsigned, unsigned RedOpcode, Type *ResTy, VectorType *Ty,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const override;
 
   InstructionCost
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index a10210ebbdd30..9b250e6cac3ab 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -1916,10 +1916,10 @@ InstructionCost ARMTTIImpl::getExtendedReductionCost(
 }
 
 InstructionCost
-ARMTTIImpl::getMulAccReductionCost(bool IsUnsigned, bool IsNegated, Type *ResTy,
-                                   VectorType *ValTy,
+ARMTTIImpl::getMulAccReductionCost(bool IsUnsigned, unsigned RedOpcode,
+                                   Type *ResTy, VectorType *ValTy,
                                    TTI::TargetCostKind CostKind) const {
-  if (IsNegated)
+  if (RedOpcode != Instruction::Add)
     return InstructionCost::getInvalid(CostKind);
   EVT ValVT = TLI->getValueType(DL, ValTy);
   EVT ResVT = TLI->getValueType(DL, ResTy);
@@ -1941,7 +1941,7 @@ ARMTTIImpl::getMulAccReductionCost(bool IsUnsigned, bool IsNegated, Type *ResTy,
       return ST->getMVEVectorCostFactor(CostKind) * LT.first;
   }
 
-  return BaseT::getMulAccReductionCost(IsUnsigned, IsNegated, ResTy, ValTy,
+  return BaseT::getMulAccReductionCost(IsUnsigned, RedOpcode, ResTy, ValTy,
                                        CostKind);
 }
 
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index fab5c915a6971..0810c5532ed91 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -299,7 +299,7 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
                            VectorType *ValTy, std::optional<FastMathFlags> FMF,
                            TTI::TargetCostKind CostKind) const override;
   InstructionCost
-  getMulAccReductionCost(bool IsUnsigned, bool IsNegated, Type *ResTy,
+  getMulAccReductionCost(bool IsUnsigned, unsigned RedOpcode, Type *ResTy,
                          VectorType *ValTy,
                          TTI::TargetCostKind CostKind) const override;
 
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index db9a1aa8352f1..7c43da0b4d552 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -5414,7 +5414,7 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
                              TTI::CastContextHint::None, CostKind, RedOp);
 
     InstructionCost RedCost = TTI.getMulAccReductionCost(
-        IsUnsigned, /*IsNegated=*/false, RdxDesc.getRecurrenceType(), ExtType,
+        IsUnsigned, RdxDesc.getOpcode(), RdxDesc.getRecurrenceType(), ExtType,
         CostKind);
 
     if (RedCost.isValid() &&
@@ -5460,7 +5460,7 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
           TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind);
 
       InstructionCost RedCost = TTI.getMulAccReductionCost(
-          IsUnsigned, /*IsNegated=*/false, RdxDesc.getRecurrenceType(), ExtType,
+          IsUnsigned, RdxDesc.getOpcode(), RdxDesc.getRecurrenceType(), ExtType,
           CostKind);
       InstructionCost ExtraExtCost = 0;
       if (Op0Ty != LargestOpTy || Op1Ty != LargestOpTy) {
@@ -5480,7 +5480,7 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
           TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind);
 
       InstructionCost RedCost = TTI.getMulAccReductionCost(
-          true, /*IsNegated=*/false, RdxDesc.getRecurrenceType(), VectorTy,
+          true, RdxDesc.getOpcode(), RdxDesc.getRecurrenceType(), VectorTy,
           CostKind);
 
       if (RedCost.isValid() && RedCost < MulCost + BaseCost)
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index d49107b300ceb..8ed26f23c859b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2869,10 +2869,6 @@ class VPExpressionRecipe : public VPSingleDefRecipe {
     /// vector operands, performing a reduction.add on the result, and adding
     /// the scalar result to a chain.
     MulAccReduction,
-    /// Represent an inloop multiply-accumulate reduction, multiplying the
-    /// extended vector operands, negating the multiplication, performing a
-    /// reduction.add on the result, and adding the scalar result to a chain.
-    ExtNegatedMulAccReduction,
   };
 
   /// Type of the expression.
@@ -2896,11 +2892,6 @@ class VPExpressionRecipe : public VPSingleDefRecipe {
                      VPWidenRecipe *Mul, VPReductionRecipe *Red)
       : VPExpressionRecipe(ExpressionTypes::ExtMulAccReduction,
                            {Ext0, Ext1, Mul, Red}) {}
-  VPExpressionRecipe(VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1,
-                     VPWidenRecipe *Mul, VPWidenRecipe *Sub,
-                     VPReductionRecipe *Red)
-      : VPExpressionRecipe(ExpressionTypes::ExtNegatedMulAccReduction,
-                           {Ext0, Ext1, Mul, Sub, Red}) {}
 
   ~VPExpressionRecipe() override {
     for (auto *R : reverse(ExpressionRecipes))
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index b374ef1341806..93e97faaefb4e 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2803,10 +2803,10 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF,
       toVectorTy(Ctx.Types.inferScalarType(getOperand(0)), VF));
   assert(RedTy->isIntegerTy() &&
          "VPExpressionRecipe only supports integer types currently.");
+  unsigned Opcode = RecurrenceDescriptor::getOpcode(
+      cast<VPReductionRecipe>(ExpressionRecipes.back())->getRecurrenceKind());
   switch (ExpressionType) {
   case ExpressionTypes::ExtendedReduction: {
-    unsigned Opcode = RecurrenceDescriptor::getOpcode(
-        cast<VPReductionRecipe>(ExpressionRecipes[1])->getRecurrenceKind());
     return Ctx.TTI.getExtendedReductionCost(
         Opcode,
         cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() ==
@@ -2814,17 +2814,14 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF,
         RedTy, SrcVecTy, std::nullopt, Ctx.CostKind);
   }
   case ExpressionTypes::MulAccReduction:
-    return Ctx.TTI.getMulAccReductionCost(false, false, RedTy, SrcVecTy,
+    return Ctx.TTI.getMulAccReductionCost(false, Opcode, RedTy, SrcVecTy,
                                           Ctx.CostKind);
 
-  case ExpressionTypes::ExtNegatedMulAccReduction:
-  case ExpressionTypes::ExtMulAccReduction: {
-    bool Negated = ExpressionType == ExpressionTypes::ExtNegatedMulAccReduction;
+  case ExpressionTypes::ExtMulAccReduction:
     return Ctx.TTI.getMulAccReductionCost(
         cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() ==
             Instruction::ZExt,
-        Negated, RedTy, SrcVecTy, Ctx.CostKind);
-  }
+        Opcode, RedTy, SrcVecTy, Ctx.CostKind);
   }
   llvm_unreachable("Unknown VPExpressionRecipe::ExpressionTypes enum");
 }
@@ -2871,31 +2868,6 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent,
     O << ")";
     break;
   }
-  case ExpressionTypes::ExtNegatedMulAccReduction: {
-    getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker);
-    O << " + ";
-    O << "reduce."
-      << Instruction::getOpcodeName(
-             RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()))
-      << " (sub (0, mul";
-    auto *Mul = cast<VPWidenRecipe>(ExpressionRecipes[2]);
-    Mul->printFlags(O);
-    O << "(";
-    getOperand(0)->printAsOperand(O, SlotTracker);
-    auto *Ext0 = cast<VPWidenCastRecipe>(ExpressionRecipes[0]);
-    O << " " << Instruction::getOpcodeName(Ext0->getOpcode()) << " to "
-      << *Ext0->getResultType() << "), (";
-    getOperand(1)->printAsOperand(O, SlotTracker);
-    auto *Ext1 = cast<VPWidenCastRecipe>(ExpressionRecipes[1]);
-    O << " " << Instruction::getOpcodeName(Ext1->getOpcode()) << " to "
-      << *Ext1->getResultType() << ")";
-    if (Red->isConditional()) {
-      O << ", ";
-      Red->getCondOp()->printAsOperand(O, SlotTracker);
-    }
-    O << "))";
-    break;
-  }
   case ExpressionTypes::MulAccReduction:
   case ExpressionTypes::ExtMulAccReduction: {
     getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker);
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 873496a7a8ad6..8cb1e55abdaeb 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -3160,7 +3160,7 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
   auto IsMulAccValidAndClampRange =
       [&](bool IsZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
           VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt,
-          bool Negated = false) -> bool {
+          unsigned Opcode) -> bool {
     return LoopVectorizationPlanner::getDecisionAndClampRange(
         [&](ElementCount VF) {
           TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
@@ -3168,7 +3168,7 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
               Ext0 ? Ctx.Types.inferScalarType(Ext0->getOperand(0)) : RedTy;
           auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF));
           InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
-              IsZExt, Negated, RedTy, SrcVecTy, CostKind);
+              IsZExt, Opcode, RedTy, SrcVecTy, CostKind);
           InstructionCost MulCost = Mul->computeCost(VF, Ctx);
           InstructionCost RedCost = Red->computeCost(VF, Ctx);
           InstructionCost ExtCost = 0;
@@ -3186,15 +3186,8 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
   };
 
   VPValue *VecOp = Red->getVecOp();
-  VPValue *Mul = nullptr;
-  VPValue *Sub = nullptr;
+  VPValue *Mul = VecOp;
   VPValue *A, *B;
-  // Sub reductions will have a sub between the add reduction and vec op.
-  if (match(VecOp,
-            m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(Mul))))
-    Sub = VecOp;
-  else
-    Mul = VecOp;
   // Try to match reduce.add(mul(...)).
   if (match(Mul, m_Mul(m_VPValue(A), m_VPValue(B)))) {
     auto *RecipeA =
@@ -3210,15 +3203,12 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
         match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
         IsMulAccValidAndClampRange(RecipeA->getOpcode() ==
                                        Instruction::CastOps::ZExt,
-                                   MulR, RecipeA, RecipeB, nullptr, Sub)) {
-      if (Sub)
-        return new VPExpressionRecipe(
-            RecipeA, RecipeB, MulR,
-            cast<VPWidenRecipe>(Sub->getDefiningRecipe()), Red);
+                                   MulR, RecipeA, RecipeB, nullptr, Opcode)) {
       return new VPExpressionRecipe(RecipeA, RecipeB, MulR, Red);
     }
     // Match reduce.add(mul).
-    if (IsMulAccValidAndClampRange(true, MulR, nullptr, nullptr, nullptr, Sub))
+    if (IsMulAccValidAndClampRange(true, MulR, nullptr, nullptr, nullptr,
+                                   Opcode))
       return new VPExpressionRecipe(MulR, Red);
   }
   // Match reduce.add(ext(mul(ext(A), ext(B)))).
@@ -3236,7 +3226,7 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
         Ext0->getOpcode() == Ext1->getOpcode() &&
         IsMulAccValidAndClampRange(Ext0->getOpcode() ==
                                        Instruction::CastOps::ZExt,
-                                   Mul, Ext0, Ext1, Ext)) {
+                                   Mul, Ext0, Ext1, Ext, Opcode)) {
       auto *NewExt0 = new VPWidenCastRecipe(
           Ext0->getOpcode(), Ext0->getOperand(0), Ext->getResultType(), *Ext0,
           Ext0->getDebugLoc());
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index c5d327ec99016..bc93cc6ab725a 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1469,7 +1469,7 @@ static void analyzeCostOfVecReduction(const IntrinsicInst &II,
 
     CostBeforeReduction = ExtCost * 2 + MulCost + Ext2Cost;
     CostAfterReduction = TTI.getMulAccReductionCost(
-        IsUnsigned, /*IsNegated=*/false, II.getType(), ExtType, CostKind);
+        IsUnsigned, ReductionOpc, II.getType(), ExtType, CostKind);
     return;
   }
   CostAfterReduction = TTI.getArithmeticReductionCost(ReductionOpc, VecRedTy,
diff --git a/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll
index 8059ac12ecd2e..1409df4df90e8 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll
@@ -417,7 +417,53 @@ exit:
   ret i64 %r.0.lcssa
 }
 
+define i64 @print_extended_sub_reduction(ptr nocapture readonly %x, ptr nocapture readonly %y, i32 %n) {
+; CHECK-LABEL: 'print_extended_sub_reduction'
+; CHECK:      VPlan 'Initial VPlan for VF={4},UF>=1' {
+; CHECK-NEXT: Live-in vp<[[VF:%.+]]> = VF
+; CHECK-NEXT: Live-in vp<[[VFxUF:%.+]]> = VF * UF
+; CHECK-NEXT: Live-in vp<[[VTC:%.+]]> = vector-trip-count
+; CHECK-NEXT: Live-in ir<%n> = original trip-count
+; CHECK-EMPTY:
+; CHECK:      vector.ph:
+; CHECK-NEXT:   EMIT vp<[[RDX_START:%.+]]> = reduction-start-vector ir<0>, ir<0>, ir<1>
+; CHECK-NEXT: Successor(s): vector loop
+; CHECK-EMPTY:
+; CHECK-NEXT: <x1> vector loop: {
+; CHECK-NEXT:   vector.body:
+; CHECK-NEXT:     EMIT vp<[[IV:%.+]]> = CANONICAL-INDUCTION ir<0>, vp<[[IV_NEXT:%.+]]>
+; CHECK-NEXT:     WIDEN-REDUCTION-PHI ir<[[RDX:%.+]]> = phi vp<[[RDX_START]]>, vp<[[RDX_NEXT:%.+]]>
+; CHECK-NEXT:     vp<[[STEPS:%.+]]> = SCALAR-STEPS vp<[[IV]]>, ir<1>
+; CHECK-NEXT:     CLONE ir<%arrayidx> = getelementptr inbounds ir<%x>, vp<[[STEPS]]>
+; CHECK-NEXT:     vp<[[ADDR:%.+]]> = vector-pointer ir<%arrayidx>
+; CHECK-NEXT:     WIDEN ir<[[LOAD:%.+]]> = load vp<[[ADDR]]>
+; CHECK-NEXT:     EXPRESSION vp<[[RDX_NEXT]]> = ir<[[RDX]]> + reduce.sub (ir<[[LOAD]]> zext to i64)
+; CHECK-NEXT:     EMIT vp<[[IV_NEXT]]> = add nuw vp<[[IV]]>, vp<[[VFxUF]]>
+; CHECK-NEXT:     EMIT branch-on-count vp<[[IV_NEXT]]>, vp<[[VTC]]>
+; CHECK-NEXT:   No successors
+; CHECK-NEXT: }
+;
+entry:
+  br label %loop
+
+loop:
+  %iv = phi i32 [ %iv.next, %loop ], [ 0, %entry ]
+  %rdx = phi i64 [ %rdx.next, %loop ], [ 0, %entry ]
+  %arrayidx = getelementptr inbounds i32, ptr %x, i32 %iv
+  %load0 = load i32, ptr %arrayidx, align 4
+  %conv0 = zext i32 %load0 to i64
+  %rdx.next = sub nsw i64 %rdx, %conv0
+  %iv.next = add nuw nsw i32 %iv, 1
+  %exitcond = icmp eq i32 %iv.next, %n
+  br i1 %exitcond, label %exit, label %loop
+
+exit:
+  %r.0.lcssa = phi i64 [ %rdx.next, %loop ]
+  ret i64 %r.0.lcssa
+}
+
 define i32 @print_mulacc_sub(ptr %a, ptr %b) {
+; CHECK-LABEL: 'print_mulacc_sub'
 ; CHECK:      VPlan 'Initial VPlan for VF={4},UF>=1' {
 ; CHECK-NEXT: Live-in vp<%0> = VF
 ; CHECK-NEXT: Live-in vp<%1> = VF * UF
@@ -442,7 +488,7 @@ define i32 @print_mulacc_sub(ptr %a, ptr %b) {
 ; CHECK-NEXT:     CLONE ir<%gep.b> = getelementptr ir<%b>, vp<%5>
 ; CHECK-NEXT:     vp<%7> = vector-pointer ir<%gep.b>
 ; CHECK-NEXT:     WIDEN ir<%load.b> = load vp<%7>
-; CHECK-NEXT:     EXPRESSION vp<%8> = ir<%accum> + reduce.add (sub (0, mul (ir<%load.b> zext to i32), (ir<%load.a> zext to i32)))
+; CHECK-NEXT:     EXPRESSION vp<%8> = ir<%accum> + reduce.sub (mul (ir<%load.b> zext to i32), (ir<%load.a> zext to i32))
 ; CHECK-NEXT:     EMIT vp<%index.next> = add nuw vp<%4>, vp<%1>
 ; CHECK-NEXT:     EMIT branch-on-count vp<%index.next>, vp<%2>
 ; CHECK-NEXT:   No successors
@@ -480,7 +526,6 @@ define i32 @print_mulacc_sub(ptr %a, ptr %b) {
 ; CHECK-NEXT: No successors
 ; CHECK-NEXT: }
 ; CHECK:      VPlan 'Final VPlan for VF={4},UF={1}' {
-; CHECK-NEXT: Live-in ir<4> = VF * UF
 ; CHECK-NEXT: Live-in ir<1024> = vector-trip-count
 ; CHECK-NEXT: Live-in ir<1024> = original trip-count
 ; CHECK-EMPTY:
@@ -492,40 +537,33 @@ define i32 @print_mulacc_sub(ptr %a, ptr %b) {
 ; CHECK-EMPTY:
 ; CHECK-NEXT: vector.body:
 ; CHECK-NEXT:   EMIT-SCALAR vp<%index> = phi [ ir<0>, ir-bb<vector.ph> ], [ vp<%index.next>, vector.body ]
-; CHECK-NEXT:   WIDEN-REDUCTION-PHI ir<%accum> = phi ir<0>, ir<%add>.1
+; CHECK-NEXT:   WIDEN-REDUCTION-PHI ir<%accum> = phi ir<0>, ir<%add>
 ; CHECK-NEXT:   CLONE ir<%gep.a> = getelementptr ir<%a>, vp<%index>
-; CHECK-NEXT:   vp<%1> = vector-pointer ir<%gep.a>
-; CHECK-NEXT:   WIDEN ir<%load.a> = load vp<%1>
+; CHECK-NEXT:   WIDEN ir<%load.a> = load ir<%gep.a>
 ; CHECK-NEXT:   CLONE ir<%gep.b> = getelementptr ir<%b>, vp<%index>
-; CHECK-NEXT:   vp<%2> = vector-pointer ir<%gep.b>
-; CHECK-NEXT:   WIDEN ir<%load.b> = load vp<%2>
+; CHECK-NEXT:   WIDEN ir<%load.b> = load ir<%gep.b>
 ; CHECK-NEXT:   WIDEN-CAST ir<%ext.b> = zext ir<%load.b> to i32
 ; CHECK-NEXT:   WIDEN-CAST ir<%ext.a> = zext ir<%load.a> to i32
 ; CHECK-NEXT:   WIDEN ir<%mul> = mul ir<%ext.b>, ir<%ext.a>
-; CHECK-NEXT:   WIDEN ir<%add> = sub ir<0>, ir<%mul>
-; CHECK-NEXT:   REDUCE ir<%add>.1 = ir<%accum> + reduce.add (ir<%add>)
+; CHECK-NEXT:   REDUCE ir<%add> = ir<%accum> + reduce.sub (ir<%mul>)
 ; CHECK-NEXT:   EMIT vp<%index.next> = add nuw vp<%index>, ir<4>
 ; CHECK-NEXT:   EMIT branch-on-count vp<%index.next>, ir<1024>
 ; CHECK-NEXT: Successor(s): middle.block, vector.body
 ; CHECK-EMPTY:
 ; CHECK-NEXT: middle.block:
-; CHECK-NEXT:   EMIT vp<%4> = compute-reduction-result ir<%accum>, ir<%add>.1
-; CHECK-NEXT:   EMIT vp<%cmp.n> = icmp eq ir<1024>, ir<1024>
-; CHECK-NEXT:   EMIT branch-on-cond vp<%cmp.n>
-; CHECK-NEXT: Successor(s): ir-bb<for.exit>, ir-bb<scalar.ph>
+; CHECK-NEXT:   EMIT vp<%2> = compute-reduction-result ir<%accum>, ir<%add>
+; CHECK-NEXT: Successor(s): ir-bb<for.exit>
 ; CHECK-EMPTY:
 ; CHECK-NEXT: ir-bb<for.exit>:
-; CHECK-NEXT:   IR   %add.lcssa = phi i32 [ %add, %for.body ] (extra operand: vp<%4> from middle.block)
+; CHECK-NEXT:   IR   %add.lcssa = phi i32 [ %add, %for.body ] (extra operand: vp<%2> from middle.block)
 ; CHECK-NEXT: No successors
 ; CHECK-EMPTY:
 ; CHECK-NEXT: ir-bb<scalar.ph>:
-; CHECK-NEXT:   EMIT-SCALAR vp<%bc.resume.val> = phi [ ir<1024>, middle.block ], [ ir<0>, ir-bb<entry> ]
-; CHECK-NEXT:   EMIT-SCALAR vp<%bc.merge.rdx> = phi [ vp<%4>, middle.block ], [ ir<0>, ir-bb<entry> ]
 ; CHECK-NEXT: Successor(s): ir-bb<for.body>
 ; CHECK-EMPTY:
 ; CHECK-NEXT: ir-bb<for.body>:
-; CHECK-NEXT:   IR   %iv = phi i64 [ 0, %scalar.ph ], [ %iv.next, %for.body ] (extra operand: vp<%bc.resume.val> from ir-bb<scalar.ph>)
-; CHECK-NEXT:   IR   %accum = phi i32 [ 0, %scalar.ph ], [ %add, %for.body ] (extra operand: vp<%bc.merge.rdx> from ir-bb<scalar.ph>)
+; CHECK-NEXT:   IR   %iv = phi i64 [ 0, %scalar.ph ], [ %iv.next, %for.body ] (extra operand: ir<0> from ir-bb<scalar.ph>)
+; CHECK-NEXT:   IR   %accum = phi i32 [ 0, %scalar.ph ], [ %add, %for.body ] (extra operand: ir<0> from ir-bb<scalar.ph>)
 ; CHECK-NEXT:   IR   %gep.a = getelementptr i8, ptr %a, i64 %iv
 ; CHECK-NEXT:   IR   %load.a = load i8, ptr %gep.a, align 1
 ; CHECK-NEXT:   IR   %ext.a = zext i8 %load.a to i32
@@ -559,3 +597,56 @@ for.body:                                         ; preds = %for.body, %entry
 for.exit:                        ; preds = %for.body
   ret i32 %add
 }
+
+define i64 @print_mulacc_sub_extended(ptr nocapture readonly %x, ptr nocapture readonly %y, i32 %n) {
+; CHECK-LABEL: 'print_mulacc_sub_extended'
+; CHECK:      VPlan 'Initial VPlan for VF={4},UF>=1' {
+; CHECK-NEXT: Live-in vp<[[VF:%.+]]> = VF
+; CHECK-NEXT: Live-in vp<[[VFxUF:%.+]]> = VF * UF
+; CHECK-NEXT: Live-in vp<[[VTC:%.+]]> = vector-trip-count
+; CHECK-NEXT: Live-in ir<%n> = original trip-count
+; CHECK-EMPTY:
+; CHECK:      vector.ph:
+; CHECK-NEXT:   EMIT vp<[[RDX_START:%.+]]> = reduction-start-vector ir<0>, ir<0>, ir<1>
+; CHECK-NEXT: Successor(s): vector loop
+; CHECK-EMPTY:
+; CHECK-NEXT: <x1> vector loop: {
+; CHECK-NEXT:   vector.body:
+; CHECK-NEXT:     EMIT vp<[[IV:%.+]]> = CANONICAL-INDUCTION ir<0>, vp<[[IV_NEXT:%.+]]>
+; CHECK-NEXT:     WIDEN-REDUCTION-PHI ir<[[RDX:%.+]]> = phi vp<[[RDX_START]]>, vp<[[RDX_NEXT:%.+]]>
+; CHECK-NEXT:     vp<[[STEPS:%.+]]> = SCALAR-STEPS vp<[[IV]]>, ir<1>
+; CHECK-NEXT:     CLONE ir<[[ARRAYIDX0:%.+]]> = getelementptr inbounds ir<%x>, vp<[[STEPS]]>
+; CHECK-NEXT:     vp<[[ADDR0:%.+]]> = vector-pointer ir<[[ARRAYIDX0]]>
+; CHECK-NEXT:     WIDEN ir<[[LOAD0:%.+]]> = load vp<[[ADDR0]]>
+; CHECK-NEXT:     CLONE ir<[[ARRAYIDX1:%.+]]> = getelementptr inbounds ir<%y>, vp<[[STEPS]]>
+; CHECK-NEXT:     vp<[[ADDR1:%.+]]> = vector-pointer ir<[[ARRAYIDX1]]>
+; CHECK-NEXT:     WIDEN ir<[[LOAD1:%.+]]> = load vp<[[ADDR1]]>
+; CHECK-NEXT:     EXPRESSION vp<[[RDX_NEXT]]> = ir<[[RDX]]> + reduce.sub (mul nsw (ir<[[LOAD0]]> sext to i64), (ir<[[LOAD1]]> sext to i64))
+; CHECK-NEXT:     EMIT vp<[[IV_NEXT]]> = add nuw vp<[[IV]]>, vp<[[VFxUF]]>
+; CHECK-NEXT:     EMIT branch-on-count vp<[[IV_NEXT]]>, vp<[[VTC]]>
+; CHECK-NEXT:   No successors
+; CHECK-NEXT: }
+;
+entry:
+  br label %loop
+
+loop:
+  %iv = phi i32 [ %iv.next, %loop ], [ 0, %entry ]
+  %rdx = phi i64 [ %rdx.next, %loop ], [ 0, %entry ]
+  %arrayidx = getelementptr inbounds i16, ptr %x, i32 %iv
+  %load0 = load i16, ptr %arrayidx, align 4
+  %arrayidx1 = getelementptr inbounds i16, ptr %y, i32 %iv
+  %load1 = load i16, ptr %arrayidx1, align 4
+  %conv0 = sext i16 %load0 to i32
+  %conv1 = sext i16 %load1 to i32
+  %mul = mul nsw i32 %conv0, %conv1
+  %conv = sext i32 %mul to i64
+  %rdx.next = sub nsw i64 %rdx, %conv
+  %iv.next = add nuw nsw i32 %iv, 1
+  %exitcond = icmp eq i32 %iv.next, %n
+  br i1 %exitcond, label %exit, label %loop
+
+exit:
+  %r.0.lcssa = phi i64 [ %rdx.next, %loop ]
+  ret i64 %r.0.lcssa
+}

>From 27e462a99702206329cd969624b84d3fddede171 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Thu, 21 Aug 2025 16:24:09 +0100
Subject: [PATCH 06/10] Rename printing test blocks

---
 .../vplan-printing-reductions.ll              | 40 +++++++++----------
 1 file changed, 20 insertions(+), 20 deletions(-)

diff --git a/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll
index 1409df4df90e8..c37113e517cb3 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll
@@ -499,20 +499,20 @@ define i32 @print_mulacc_sub(ptr %a, ptr %b) {
 ; CHECK-NEXT:   EMIT vp<%10> = compute-reduction-result ir<%accum>, vp<%8>
 ; CHECK-NEXT:   EMIT vp<%cmp.n> = icmp eq ir<1024>, vp<%2>
 ; CHECK-NEXT:   EMIT branch-on-cond vp<%cmp.n>
-; CHECK-NEXT: Successor(s): ir-bb<for.exit>, scalar.ph
+; CHECK-NEXT: Successor(s): ir-bb<exit>, scalar.ph
 ; CHECK-EMPTY:
-; CHECK-NEXT: ir-bb<for.exit>:
-; CHECK-NEXT:   IR   %add.lcssa = phi i32 [ %add, %for.body ] (extra operand: vp<%10> from middle.block)
+; CHECK-NEXT: ir-bb<exit>:
+; CHECK-NEXT:   IR   %add.lcssa = phi i32 [ %add, %loop ] (extra operand: vp<%10> from middle.block)
 ; CHECK-NEXT: No successors
 ; CHECK-EMPTY:
 ; CHECK-NEXT: scalar.ph:
 ; CHECK-NEXT:   EMIT-SCALAR vp<%bc.resume.val> = phi [ vp<%2>, middle.block ], [ ir<0>, ir-bb<entry> ]
 ; CHECK-NEXT:   EMIT-SCALAR vp<%bc.merge.rdx> = phi [ vp<%10>, middle.block ], [ ir<0>, ir-bb<entry> ]
-; CHECK-NEXT: Successor(s): ir-bb<for.body>
+; CHECK-NEXT: Successor(s): ir-bb<loop>
 ; CHECK-EMPTY:
-; CHECK-NEXT: ir-bb<for.body>:
-; CHECK-NEXT:   IR   %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ] (extra operand: vp<%bc.resume.val> from scalar.ph)
-; CHECK-NEXT:   IR   %accum = phi i32 [ 0, %entry ], [ %add, %for.body ] (extra operand: vp<%bc.merge.rdx> from scalar.ph)
+; CHECK-NEXT: ir-bb<loop>:
+; CHECK-NEXT:   IR   %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ] (extra operand: vp<%bc.resume.val> from scalar.ph)
+; CHECK-NEXT:   IR   %accum = phi i32 [ 0, %entry ], [ %add, %loop ] (extra operand: vp<%bc.merge.rdx> from scalar.ph)
 ; CHECK-NEXT:   IR   %gep.a = getelementptr i8, ptr %a, i64 %iv
 ; CHECK-NEXT:   IR   %load.a = load i8, ptr %gep.a, align 1
 ; CHECK-NEXT:   IR   %ext.a = zext i8 %load.a to i32
@@ -552,18 +552,18 @@ define i32 @print_mulacc_sub(ptr %a, ptr %b) {
 ; CHECK-EMPTY:
 ; CHECK-NEXT: middle.block:
 ; CHECK-NEXT:   EMIT vp<%2> = compute-reduction-result ir<%accum>, ir<%add>
-; CHECK-NEXT: Successor(s): ir-bb<for.exit>
+; CHECK-NEXT: Successor(s): ir-bb<exit>
 ; CHECK-EMPTY:
-; CHECK-NEXT: ir-bb<for.exit>:
-; CHECK-NEXT:   IR   %add.lcssa = phi i32 [ %add, %for.body ] (extra operand: vp<%2> from middle.block)
+; CHECK-NEXT: ir-bb<exit>:
+; CHECK-NEXT:   IR   %add.lcssa = phi i32 [ %add, %loop ] (extra operand: vp<%2> from middle.block)
 ; CHECK-NEXT: No successors
 ; CHECK-EMPTY:
 ; CHECK-NEXT: ir-bb<scalar.ph>:
-; CHECK-NEXT: Successor(s): ir-bb<for.body>
+; CHECK-NEXT: Successor(s): ir-bb<loop>
 ; CHECK-EMPTY:
-; CHECK-NEXT: ir-bb<for.body>:
-; CHECK-NEXT:   IR   %iv = phi i64 [ 0, %scalar.ph ], [ %iv.next, %for.body ] (extra operand: ir<0> from ir-bb<scalar.ph>)
-; CHECK-NEXT:   IR   %accum = phi i32 [ 0, %scalar.ph ], [ %add, %for.body ] (extra operand: ir<0> from ir-bb<scalar.ph>)
+; CHECK-NEXT: ir-bb<loop>:
+; CHECK-NEXT:   IR   %iv = phi i64 [ 0, %scalar.ph ], [ %iv.next, %loop ] (extra operand: ir<0> from ir-bb<scalar.ph>)
+; CHECK-NEXT:   IR   %accum = phi i32 [ 0, %scalar.ph ], [ %add, %loop ] (extra operand: ir<0> from ir-bb<scalar.ph>)
 ; CHECK-NEXT:   IR   %gep.a = getelementptr i8, ptr %a, i64 %iv
 ; CHECK-NEXT:   IR   %load.a = load i8, ptr %gep.a, align 1
 ; CHECK-NEXT:   IR   %ext.a = zext i8 %load.a to i32
@@ -577,11 +577,11 @@ define i32 @print_mulacc_sub(ptr %a, ptr %b) {
 ; CHECK-NEXT: No successors
 ; CHECK-NEXT: }
 entry:
-  br label %for.body
+  br label %loop
 
-for.body:                                         ; preds = %for.body, %entry
-  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
-  %accum = phi i32 [ 0, %entry ], [ %add, %for.body ]
+loop:
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
+  %accum = phi i32 [ 0, %entry ], [ %add, %loop ]
   %gep.a = getelementptr i8, ptr %a, i64 %iv
   %load.a = load i8, ptr %gep.a, align 1
   %ext.a = zext i8 %load.a to i32
@@ -592,9 +592,9 @@ for.body:                                         ; preds = %for.body, %entry
   %add = sub i32 %accum, %mul
   %iv.next = add i64 %iv, 1
   %exitcond.not = icmp eq i64 %iv.next, 1024
-  br i1 %exitcond.not, label %for.exit, label %for.body
+  br i1 %exitcond.not, label %exit, label %loop
 
-for.exit:                        ; preds = %for.body
+exit:
   ret i32 %add
 }
 

>From 67a06044d46e907e828be98487276f6818ed2a73 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Wed, 27 Aug 2025 14:57:11 +0100
Subject: [PATCH 07/10] Address review

---
 llvm/include/llvm/CodeGen/BasicTTIImpl.h      |   4 +-
 .../AArch64/AArch64TargetTransformInfo.cpp    |   5 +-
 .../Transforms/Vectorize/VPlanTransforms.cpp  |   5 +-
 .../SLPVectorizer/AArch64/vecreduceadd.ll     | 261 ++++++++++++++++++
 4 files changed, 268 insertions(+), 7 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 823826edabc5a..dce423fc1b18b 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -3265,7 +3265,9 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
                          TTI::TargetCostKind CostKind) const override {
     // Without any native support, this is equivalent to the cost of
     // vecreduce.add(mul(ext(Ty A), ext(Ty B))) or
-    // vecreduce.add(mul(A, B)). IsNegated determines if the mul is negated.
+    // vecreduce.add(mul(A, B)).
+    assert((RedOpcode == Instruction::Add || RedOpcode == Instruction::Sub) &&
+           "The reduction opcode is expected to be Add or Sub.");
     VectorType *ExtTy = VectorType::get(ResTy, Ty);
     InstructionCost RedCost = thisT()->getArithmeticReductionCost(
         RedOpcode, ExtTy, std::nullopt, CostKind);
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 2b09aa0f9c0cf..922da10f4e39f 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5489,12 +5489,11 @@ InstructionCost
 AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, unsigned RedOpcode,
                                        Type *ResTy, VectorType *VecTy,
                                        TTI::TargetCostKind CostKind) const {
-  if (RedOpcode != Instruction::Add)
-    return InstructionCost::getInvalid(CostKind);
   EVT VecVT = TLI->getValueType(DL, VecTy);
   EVT ResVT = TLI->getValueType(DL, ResTy);
 
-  if (ST->hasDotProd() && VecVT.isSimple() && ResVT.isSimple()) {
+  if (ST->hasDotProd() && VecVT.isSimple() && ResVT.isSimple() &&
+      RedOpcode == Instruction::Add) {
     std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(VecTy);
 
     // The legal cases with dotprod are
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 8cb1e55abdaeb..2b377e4010b99 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -3186,15 +3186,14 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
   };
 
   VPValue *VecOp = Red->getVecOp();
-  VPValue *Mul = VecOp;
   VPValue *A, *B;
   // Try to match reduce.add(mul(...)).
-  if (match(Mul, m_Mul(m_VPValue(A), m_VPValue(B)))) {
+  if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
     auto *RecipeA =
         dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
     auto *RecipeB =
         dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
-    auto *MulR = cast<VPWidenRecipe>(Mul->getDefiningRecipe());
+    auto *MulR = cast<VPWidenRecipe>(VecOp->getDefiningRecipe());
 
     // Match reduce.add(mul(ext, ext)).
     if (RecipeA && RecipeB &&
diff --git a/llvm/test/Transforms/SLPVectorizer/AArch64/vecreduceadd.ll b/llvm/test/Transforms/SLPVectorizer/AArch64/vecreduceadd.ll
index 36826eb6681c8..c1a87f0c5f907 100644
--- a/llvm/test/Transforms/SLPVectorizer/AArch64/vecreduceadd.ll
+++ b/llvm/test/Transforms/SLPVectorizer/AArch64/vecreduceadd.ll
@@ -1149,3 +1149,264 @@ entry:
   %add.15 = add nsw i32 %mul.15, %add.14
   ret i32 %add.15
 }
+
+; COST-LABEL: Function:  mla_v16i8_i32_sub
+; COST: Cost:            '-2'
+define i32 @mla_v16i8_i32_sub(ptr %x, ptr %y) "target-features"="+dotprod" {
+; CHECK-LABEL: @mla_v16i8_i32_sub(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = load i8, ptr [[X:%.*]], align 1
+; CHECK-NEXT:    [[CONV:%.*]] = sext i8 [[TMP0]] to i32
+; CHECK-NEXT:    [[TMP1:%.*]] = load i8, ptr [[Y:%.*]], align 1
+; CHECK-NEXT:    [[CONV3:%.*]] = sext i8 [[TMP1]] to i32
+; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i32 [[CONV3]], [[CONV]]
+; CHECK-NEXT:    [[ARRAYIDX_1:%.*]] = getelementptr inbounds nuw i8, ptr [[X]], i64 1
+; CHECK-NEXT:    [[TMP2:%.*]] = load i8, ptr [[ARRAYIDX_1]], align 1
+; CHECK-NEXT:    [[CONV_1:%.*]] = sext i8 [[TMP2]] to i32
+; CHECK-NEXT:    [[ARRAYIDX2_1:%.*]] = getelementptr inbounds nuw i8, ptr [[Y]], i64 1
+; CHECK-NEXT:    [[TMP3:%.*]] = load i8, ptr [[ARRAYIDX2_1]], align 1
+; CHECK-NEXT:    [[CONV3_1:%.*]] = sext i8 [[TMP3]] to i32
+; CHECK-NEXT:    [[MUL_1:%.*]] = mul nsw i32 [[CONV3_1]], [[CONV_1]]
+; CHECK-NEXT:    [[SUB_1:%.*]] = sub nsw i32 [[MUL_1]], [[MUL]]
+; CHECK-NEXT:    [[ARRAYIDX_2:%.*]] = getelementptr inbounds nuw i8, ptr [[X]], i64 2
+; CHECK-NEXT:    [[TMP4:%.*]] = load i8, ptr [[ARRAYIDX_2]], align 1
+; CHECK-NEXT:    [[CONV_2:%.*]] = sext i8 [[TMP4]] to i32
+; CHECK-NEXT:    [[ARRAYIDX2_2:%.*]] = getelementptr inbounds nuw i8, ptr [[Y]], i64 2
+; CHECK-NEXT:    [[TMP5:%.*]] = load i8, ptr [[ARRAYIDX2_2]], align 1
+; CHECK-NEXT:    [[CONV3_2:%.*]] = sext i8 [[TMP5]] to i32
+; CHECK-NEXT:    [[MUL_2:%.*]] = mul nsw i32 [[CONV3_2]], [[CONV_2]]
+; CHECK-NEXT:    [[SUB_2:%.*]] = sub nsw i32 [[MUL_2]], [[SUB_1]]
+; CHECK-NEXT:    [[ARRAYIDX_3:%.*]] = getelementptr inbounds nuw i8, ptr [[X]], i64 3
+; CHECK-NEXT:    [[TMP6:%.*]] = load i8, ptr [[ARRAYIDX_3]], align 1
+; CHECK-NEXT:    [[CONV_3:%.*]] = sext i8 [[TMP6]] to i32
+; CHECK-NEXT:    [[ARRAYIDX2_3:%.*]] = getelementptr inbounds nuw i8, ptr [[Y]], i64 3
+; CHECK-NEXT:    [[TMP7:%.*]] = load i8, ptr [[ARRAYIDX2_3]], align 1
+; CHECK-NEXT:    [[CONV3_3:%.*]] = sext i8 [[TMP7]] to i32
+; CHECK-NEXT:    [[MUL_3:%.*]] = mul nsw i32 [[CONV3_3]], [[CONV_3]]
+; CHECK-NEXT:    [[SUB_3:%.*]] = sub nsw i32 [[MUL_3]], [[SUB_2]]
+; CHECK-NEXT:    [[ARRAYIDX_4:%.*]] = getelementptr inbounds nuw i8, ptr [[X]], i64 4
+; CHECK-NEXT:    [[TMP8:%.*]] = load i8, ptr [[ARRAYIDX_4]], align 1
+; CHECK-NEXT:    [[CONV_4:%.*]] = sext i8 [[TMP8]] to i32
+; CHECK-NEXT:    [[ARRAYIDX2_4:%.*]] = getelementptr inbounds nuw i8, ptr [[Y]], i64 4
+; CHECK-NEXT:    [[TMP9:%.*]] = load i8, ptr [[ARRAYIDX2_4]], align 1
+; CHECK-NEXT:    [[CONV3_4:%.*]] = sext i8 [[TMP9]] to i32
+; CHECK-NEXT:    [[MUL_4:%.*]] = mul nsw i32 [[CONV3_4]], [[CONV_4]]
+; CHECK-NEXT:    [[SUB_4:%.*]] = sub nsw i32 [[MUL_4]], [[SUB_3]]
+; CHECK-NEXT:    [[ARRAYIDX_5:%.*]] = getelementptr inbounds nuw i8, ptr [[X]], i64 5
+; CHECK-NEXT:    [[TMP10:%.*]] = load i8, ptr [[ARRAYIDX_5]], align 1
+; CHECK-NEXT:    [[CONV_5:%.*]] = sext i8 [[TMP10]] to i32
+; CHECK-NEXT:    [[ARRAYIDX2_5:%.*]] = getelementptr inbounds nuw i8, ptr [[Y]], i64 5
+; CHECK-NEXT:    [[TMP11:%.*]] = load i8, ptr [[ARRAYIDX2_5]], align 1
+; CHECK-NEXT:    [[CONV3_5:%.*]] = sext i8 [[TMP11]] to i32
+; CHECK-NEXT:    [[MUL_5:%.*]] = mul nsw i32 [[CONV3_5]], [[CONV_5]]
+; CHECK-NEXT:    [[SUB_5:%.*]] = sub nsw i32 [[MUL_5]], [[SUB_4]]
+; CHECK-NEXT:    [[ARRAYIDX_6:%.*]] = getelementptr inbounds nuw i8, ptr [[X]], i64 6
+; CHECK-NEXT:    [[TMP12:%.*]] = load i8, ptr [[ARRAYIDX_6]], align 1
+; CHECK-NEXT:    [[CONV_6:%.*]] = sext i8 [[TMP12]] to i32
+; CHECK-NEXT:    [[ARRAYIDX2_6:%.*]] = getelementptr inbounds nuw i8, ptr [[Y]], i64 6
+; CHECK-NEXT:    [[TMP13:%.*]] = load i8, ptr [[ARRAYIDX2_6]], align 1
+; CHECK-NEXT:    [[CONV3_6:%.*]] = sext i8 [[TMP13]] to i32
+; CHECK-NEXT:    [[MUL_6:%.*]] = mul nsw i32 [[CONV3_6]], [[CONV_6]]
+; CHECK-NEXT:    [[SUB_6:%.*]] = sub nsw i32 [[MUL_6]], [[SUB_5]]
+; CHECK-NEXT:    [[ARRAYIDX_7:%.*]] = getelementptr inbounds nuw i8, ptr [[X]], i64 7
+; CHECK-NEXT:    [[TMP14:%.*]] = load i8, ptr [[ARRAYIDX_7]], align 1
+; CHECK-NEXT:    [[CONV_7:%.*]] = sext i8 [[TMP14]] to i32
+; CHECK-NEXT:    [[ARRAYIDX2_7:%.*]] = getelementptr inbounds nuw i8, ptr [[Y]], i64 7
+; CHECK-NEXT:    [[TMP15:%.*]] = load i8, ptr [[ARRAYIDX2_7]], align 1
+; CHECK-NEXT:    [[CONV3_7:%.*]] = sext i8 [[TMP15]] to i32
+; CHECK-NEXT:    [[MUL_7:%.*]] = mul nsw i32 [[CONV3_7]], [[CONV_7]]
+; CHECK-NEXT:    [[SUB_7:%.*]] = sub nsw i32 [[MUL_7]], [[SUB_6]]
+; CHECK-NEXT:    [[ARRAYIDX_8:%.*]] = getelementptr inbounds nuw i8, ptr [[X]], i64 8
+; CHECK-NEXT:    [[TMP16:%.*]] = load i8, ptr [[ARRAYIDX_8]], align 1
+; CHECK-NEXT:    [[CONV_8:%.*]] = sext i8 [[TMP16]] to i32
+; CHECK-NEXT:    [[ARRAYIDX2_8:%.*]] = getelementptr inbounds nuw i8, ptr [[Y]], i64 8
+; CHECK-NEXT:    [[TMP17:%.*]] = load i8, ptr [[ARRAYIDX2_8]], align 1
+; CHECK-NEXT:    [[CONV3_8:%.*]] = sext i8 [[TMP17]] to i32
+; CHECK-NEXT:    [[MUL_8:%.*]] = mul nsw i32 [[CONV3_8]], [[CONV_8]]
+; CHECK-NEXT:    [[SUB_8:%.*]] = sub nsw i32 [[MUL_8]], [[SUB_7]]
+; CHECK-NEXT:    [[ARRAYIDX_9:%.*]] = getelementptr inbounds nuw i8, ptr [[X]], i64 9
+; CHECK-NEXT:    [[TMP18:%.*]] = load i8, ptr [[ARRAYIDX_9]], align 1
+; CHECK-NEXT:    [[CONV_9:%.*]] = sext i8 [[TMP18]] to i32
+; CHECK-NEXT:    [[ARRAYIDX2_9:%.*]] = getelementptr inbounds nuw i8, ptr [[Y]], i64 9
+; CHECK-NEXT:    [[TMP19:%.*]] = load i8, ptr [[ARRAYIDX2_9]], align 1
+; CHECK-NEXT:    [[CONV3_9:%.*]] = sext i8 [[TMP19]] to i32
+; CHECK-NEXT:    [[MUL_9:%.*]] = mul nsw i32 [[CONV3_9]], [[CONV_9]]
+; CHECK-NEXT:    [[SUB_9:%.*]] = sub nsw i32 [[MUL_9]], [[SUB_8]]
+; CHECK-NEXT:    [[ARRAYIDX_10:%.*]] = getelementptr inbounds nuw i8, ptr [[X]], i64 10
+; CHECK-NEXT:    [[TMP20:%.*]] = load i8, ptr [[ARRAYIDX_10]], align 1
+; CHECK-NEXT:    [[CONV_10:%.*]] = sext i8 [[TMP20]] to i32
+; CHECK-NEXT:    [[ARRAYIDX2_10:%.*]] = getelementptr inbounds nuw i8, ptr [[Y]], i64 10
+; CHECK-NEXT:    [[TMP21:%.*]] = load i8, ptr [[ARRAYIDX2_10]], align 1
+; CHECK-NEXT:    [[CONV3_10:%.*]] = sext i8 [[TMP21]] to i32
+; CHECK-NEXT:    [[MUL_10:%.*]] = mul nsw i32 [[CONV3_10]], [[CONV_10]]
+; CHECK-NEXT:    [[SUB_10:%.*]] = sub nsw i32 [[MUL_10]], [[SUB_9]]
+; CHECK-NEXT:    [[ARRAYIDX_11:%.*]] = getelementptr inbounds nuw i8, ptr [[X]], i64 11
+; CHECK-NEXT:    [[TMP22:%.*]] = load i8, ptr [[ARRAYIDX_11]], align 1
+; CHECK-NEXT:    [[CONV_11:%.*]] = sext i8 [[TMP22]] to i32
+; CHECK-NEXT:    [[ARRAYIDX2_11:%.*]] = getelementptr inbounds nuw i8, ptr [[Y]], i64 11
+; CHECK-NEXT:    [[TMP23:%.*]] = load i8, ptr [[ARRAYIDX2_11]], align 1
+; CHECK-NEXT:    [[CONV3_11:%.*]] = sext i8 [[TMP23]] to i32
+; CHECK-NEXT:    [[MUL_11:%.*]] = mul nsw i32 [[CONV3_11]], [[CONV_11]]
+; CHECK-NEXT:    [[SUB_11:%.*]] = sub nsw i32 [[MUL_11]], [[SUB_10]]
+; CHECK-NEXT:    [[ARRAYIDX_12:%.*]] = getelementptr inbounds nuw i8, ptr [[X]], i64 12
+; CHECK-NEXT:    [[TMP24:%.*]] = load i8, ptr [[ARRAYIDX_12]], align 1
+; CHECK-NEXT:    [[CONV_12:%.*]] = sext i8 [[TMP24]] to i32
+; CHECK-NEXT:    [[ARRAYIDX2_12:%.*]] = getelementptr inbounds nuw i8, ptr [[Y]], i64 12
+; CHECK-NEXT:    [[TMP25:%.*]] = load i8, ptr [[ARRAYIDX2_12]], align 1
+; CHECK-NEXT:    [[CONV3_12:%.*]] = sext i8 [[TMP25]] to i32
+; CHECK-NEXT:    [[MUL_12:%.*]] = mul nsw i32 [[CONV3_12]], [[CONV_12]]
+; CHECK-NEXT:    [[SUB_12:%.*]] = sub nsw i32 [[MUL_12]], [[SUB_11]]
+; CHECK-NEXT:    [[ARRAYIDX_13:%.*]] = getelementptr inbounds nuw i8, ptr [[X]], i64 13
+; CHECK-NEXT:    [[TMP26:%.*]] = load i8, ptr [[ARRAYIDX_13]], align 1
+; CHECK-NEXT:    [[CONV_13:%.*]] = sext i8 [[TMP26]] to i32
+; CHECK-NEXT:    [[ARRAYIDX2_13:%.*]] = getelementptr inbounds nuw i8, ptr [[Y]], i64 13
+; CHECK-NEXT:    [[TMP27:%.*]] = load i8, ptr [[ARRAYIDX2_13]], align 1
+; CHECK-NEXT:    [[CONV3_13:%.*]] = sext i8 [[TMP27]] to i32
+; CHECK-NEXT:    [[MUL_13:%.*]] = mul nsw i32 [[CONV3_13]], [[CONV_13]]
+; CHECK-NEXT:    [[SUB_13:%.*]] = sub nsw i32 [[MUL_13]], [[SUB_12]]
+; CHECK-NEXT:    [[ARRAYIDX_14:%.*]] = getelementptr inbounds nuw i8, ptr [[X]], i64 14
+; CHECK-NEXT:    [[TMP28:%.*]] = load i8, ptr [[ARRAYIDX_14]], align 1
+; CHECK-NEXT:    [[CONV_14:%.*]] = sext i8 [[TMP28]] to i32
+; CHECK-NEXT:    [[ARRAYIDX2_14:%.*]] = getelementptr inbounds nuw i8, ptr [[Y]], i64 14
+; CHECK-NEXT:    [[TMP29:%.*]] = load i8, ptr [[ARRAYIDX2_14]], align 1
+; CHECK-NEXT:    [[CONV3_14:%.*]] = sext i8 [[TMP29]] to i32
+; CHECK-NEXT:    [[MUL_14:%.*]] = mul nsw i32 [[CONV3_14]], [[CONV_14]]
+; CHECK-NEXT:    [[SUB_14:%.*]] = sub nsw i32 [[MUL_14]], [[SUB_13]]
+; CHECK-NEXT:    [[ARRAYIDX_15:%.*]] = getelementptr inbounds nuw i8, ptr [[X]], i64 15
+; CHECK-NEXT:    [[TMP30:%.*]] = load i8, ptr [[ARRAYIDX_15]], align 1
+; CHECK-NEXT:    [[CONV_15:%.*]] = sext i8 [[TMP30]] to i32
+; CHECK-NEXT:    [[ARRAYIDX2_15:%.*]] = getelementptr inbounds nuw i8, ptr [[Y]], i64 15
+; CHECK-NEXT:    [[TMP31:%.*]] = load i8, ptr [[ARRAYIDX2_15]], align 1
+; CHECK-NEXT:    [[CONV3_15:%.*]] = sext i8 [[TMP31]] to i32
+; CHECK-NEXT:    [[MUL_15:%.*]] = mul nsw i32 [[CONV3_15]], [[CONV_15]]
+; CHECK-NEXT:    [[SUB_15:%.*]] = sub nsw i32 [[MUL_15]], [[SUB_14]]
+; CHECK-NEXT:    ret i32 [[SUB_15]]
+;
+entry:
+  %0 = load i8, ptr %x
+  %conv = sext i8 %0 to i32
+  %1 = load i8, ptr %y
+  %conv3 = sext i8 %1 to i32
+  %mul = mul nsw i32 %conv3, %conv
+  %arrayidx.1 = getelementptr inbounds nuw i8, ptr %x, i64 1
+  %2 = load i8, ptr %arrayidx.1
+  %conv.1 = sext i8 %2 to i32
+  %arrayidx2.1 = getelementptr inbounds nuw i8, ptr %y, i64 1
+  %3 = load i8, ptr %arrayidx2.1
+  %conv3.1 = sext i8 %3 to i32
+  %mul.1 = mul nsw i32 %conv3.1, %conv.1
+  %sub.1 = sub nsw i32 %mul.1, %mul
+  %arrayidx.2 = getelementptr inbounds nuw i8, ptr %x, i64 2
+  %4 = load i8, ptr %arrayidx.2
+  %conv.2 = sext i8 %4 to i32
+  %arrayidx2.2 = getelementptr inbounds nuw i8, ptr %y, i64 2
+  %5 = load i8, ptr %arrayidx2.2
+  %conv3.2 = sext i8 %5 to i32
+  %mul.2 = mul nsw i32 %conv3.2, %conv.2
+  %sub.2 = sub nsw i32 %mul.2, %sub.1
+  %arrayidx.3 = getelementptr inbounds nuw i8, ptr %x, i64 3
+  %6 = load i8, ptr %arrayidx.3
+  %conv.3 = sext i8 %6 to i32
+  %arrayidx2.3 = getelementptr inbounds nuw i8, ptr %y, i64 3
+  %7 = load i8, ptr %arrayidx2.3
+  %conv3.3 = sext i8 %7 to i32
+  %mul.3 = mul nsw i32 %conv3.3, %conv.3
+  %sub.3 = sub nsw i32 %mul.3, %sub.2
+  %arrayidx.4 = getelementptr inbounds nuw i8, ptr %x, i64 4
+  %8 = load i8, ptr %arrayidx.4
+  %conv.4 = sext i8 %8 to i32
+  %arrayidx2.4 = getelementptr inbounds nuw i8, ptr %y, i64 4
+  %9 = load i8, ptr %arrayidx2.4
+  %conv3.4 = sext i8 %9 to i32
+  %mul.4 = mul nsw i32 %conv3.4, %conv.4
+  %sub.4 = sub nsw i32 %mul.4, %sub.3
+  %arrayidx.5 = getelementptr inbounds nuw i8, ptr %x, i64 5
+  %10 = load i8, ptr %arrayidx.5
+  %conv.5 = sext i8 %10 to i32
+  %arrayidx2.5 = getelementptr inbounds nuw i8, ptr %y, i64 5
+  %11 = load i8, ptr %arrayidx2.5
+  %conv3.5 = sext i8 %11 to i32
+  %mul.5 = mul nsw i32 %conv3.5, %conv.5
+  %sub.5 = sub nsw i32 %mul.5, %sub.4
+  %arrayidx.6 = getelementptr inbounds nuw i8, ptr %x, i64 6
+  %12 = load i8, ptr %arrayidx.6
+  %conv.6 = sext i8 %12 to i32
+  %arrayidx2.6 = getelementptr inbounds nuw i8, ptr %y, i64 6
+  %13 = load i8, ptr %arrayidx2.6
+  %conv3.6 = sext i8 %13 to i32
+  %mul.6 = mul nsw i32 %conv3.6, %conv.6
+  %sub.6 = sub nsw i32 %mul.6, %sub.5
+  %arrayidx.7 = getelementptr inbounds nuw i8, ptr %x, i64 7
+  %14 = load i8, ptr %arrayidx.7
+  %conv.7 = sext i8 %14 to i32
+  %arrayidx2.7 = getelementptr inbounds nuw i8, ptr %y, i64 7
+  %15 = load i8, ptr %arrayidx2.7
+  %conv3.7 = sext i8 %15 to i32
+  %mul.7 = mul nsw i32 %conv3.7, %conv.7
+  %sub.7 = sub nsw i32 %mul.7, %sub.6
+  %arrayidx.8 = getelementptr inbounds nuw i8, ptr %x, i64 8
+  %16 = load i8, ptr %arrayidx.8
+  %conv.8 = sext i8 %16 to i32
+  %arrayidx2.8 = getelementptr inbounds nuw i8, ptr %y, i64 8
+  %17 = load i8, ptr %arrayidx2.8
+  %conv3.8 = sext i8 %17 to i32
+  %mul.8 = mul nsw i32 %conv3.8, %conv.8
+  %sub.8 = sub nsw i32 %mul.8, %sub.7
+  %arrayidx.9 = getelementptr inbounds nuw i8, ptr %x, i64 9
+  %18 = load i8, ptr %arrayidx.9
+  %conv.9 = sext i8 %18 to i32
+  %arrayidx2.9 = getelementptr inbounds nuw i8, ptr %y, i64 9
+  %19 = load i8, ptr %arrayidx2.9
+  %conv3.9 = sext i8 %19 to i32
+  %mul.9 = mul nsw i32 %conv3.9, %conv.9
+  %sub.9 = sub nsw i32 %mul.9, %sub.8
+  %arrayidx.10 = getelementptr inbounds nuw i8, ptr %x, i64 10
+  %20 = load i8, ptr %arrayidx.10
+  %conv.10 = sext i8 %20 to i32
+  %arrayidx2.10 = getelementptr inbounds nuw i8, ptr %y, i64 10
+  %21 = load i8, ptr %arrayidx2.10
+  %conv3.10 = sext i8 %21 to i32
+  %mul.10 = mul nsw i32 %conv3.10, %conv.10
+  %sub.10 = sub nsw i32 %mul.10, %sub.9
+  %arrayidx.11 = getelementptr inbounds nuw i8, ptr %x, i64 11
+  %22 = load i8, ptr %arrayidx.11
+  %conv.11 = sext i8 %22 to i32
+  %arrayidx2.11 = getelementptr inbounds nuw i8, ptr %y, i64 11
+  %23 = load i8, ptr %arrayidx2.11
+  %conv3.11 = sext i8 %23 to i32
+  %mul.11 = mul nsw i32 %conv3.11, %conv.11
+  %sub.11 = sub nsw i32 %mul.11, %sub.10
+  %arrayidx.12 = getelementptr inbounds nuw i8, ptr %x, i64 12
+  %24 = load i8, ptr %arrayidx.12
+  %conv.12 = sext i8 %24 to i32
+  %arrayidx2.12 = getelementptr inbounds nuw i8, ptr %y, i64 12
+  %25 = load i8, ptr %arrayidx2.12
+  %conv3.12 = sext i8 %25 to i32
+  %mul.12 = mul nsw i32 %conv3.12, %conv.12
+  %sub.12 = sub nsw i32 %mul.12, %sub.11
+  %arrayidx.13 = getelementptr inbounds nuw i8, ptr %x, i64 13
+  %26 = load i8, ptr %arrayidx.13
+  %conv.13 = sext i8 %26 to i32
+  %arrayidx2.13 = getelementptr inbounds nuw i8, ptr %y, i64 13
+  %27 = load i8, ptr %arrayidx2.13
+  %conv3.13 = sext i8 %27 to i32
+  %mul.13 = mul nsw i32 %conv3.13, %conv.13
+  %sub.13 = sub nsw i32 %mul.13, %sub.12
+  %arrayidx.14 = getelementptr inbounds nuw i8, ptr %x, i64 14
+  %28 = load i8, ptr %arrayidx.14
+  %conv.14 = sext i8 %28 to i32
+  %arrayidx2.14 = getelementptr inbounds nuw i8, ptr %y, i64 14
+  %29 = load i8, ptr %arrayidx2.14
+  %conv3.14 = sext i8 %29 to i32
+  %mul.14 = mul nsw i32 %conv3.14, %conv.14
+  %sub.14 = sub nsw i32 %mul.14, %sub.13
+  %arrayidx.15 = getelementptr inbounds nuw i8, ptr %x, i64 15
+  %30 = load i8, ptr %arrayidx.15
+  %conv.15 = sext i8 %30 to i32
+  %arrayidx2.15 = getelementptr inbounds nuw i8, ptr %y, i64 15
+  %31 = load i8, ptr %arrayidx2.15
+  %conv3.15 = sext i8 %31 to i32
+  %mul.15 = mul nsw i32 %conv3.15, %conv.15
+  %sub.15 = sub nsw i32 %mul.15, %sub.14
+  ret i32 %sub.15
+}

>From 33236a3db0356823ee1919e27cf6c89c45f56c9e Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Fri, 29 Aug 2025 12:53:38 +0100
Subject: [PATCH 08/10] Rebase

---
 .../LoopVectorize/vplan-printing-reductions.ll        | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll
index c37113e517cb3..2ffb8203d49dd 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll
@@ -530,13 +530,14 @@ define i32 @print_mulacc_sub(ptr %a, ptr %b) {
 ; CHECK-NEXT: Live-in ir<1024> = original trip-count
 ; CHECK-EMPTY:
 ; CHECK-NEXT: ir-bb<entry>:
-; CHECK-NEXT: Successor(s): ir-bb<scalar.ph>, ir-bb<vector.ph>
+; CHECK-NEXT:  EMIT branch-on-cond ir<false>
+; CHECK-NEXT: Successor(s): ir-bb<scalar.ph>, vector.ph
 ; CHECK-EMPTY:
-; CHECK-NEXT: ir-bb<vector.ph>:
+; CHECK-NEXT: vector.ph:
 ; CHECK-NEXT: Successor(s): vector.body
 ; CHECK-EMPTY:
 ; CHECK-NEXT: vector.body:
-; CHECK-NEXT:   EMIT-SCALAR vp<%index> = phi [ ir<0>, ir-bb<vector.ph> ], [ vp<%index.next>, vector.body ]
+; CHECK-NEXT:   EMIT-SCALAR vp<%index> = phi [ ir<0>, vector.ph ], [ vp<%index.next>, vector.body ]
 ; CHECK-NEXT:   WIDEN-REDUCTION-PHI ir<%accum> = phi ir<0>, ir<%add>
 ; CHECK-NEXT:   CLONE ir<%gep.a> = getelementptr ir<%a>, vp<%index>
 ; CHECK-NEXT:   WIDEN ir<%load.a> = load ir<%gep.a>
@@ -551,11 +552,11 @@ define i32 @print_mulacc_sub(ptr %a, ptr %b) {
 ; CHECK-NEXT: Successor(s): middle.block, vector.body
 ; CHECK-EMPTY:
 ; CHECK-NEXT: middle.block:
-; CHECK-NEXT:   EMIT vp<%2> = compute-reduction-result ir<%accum>, ir<%add>
+; CHECK-NEXT:   EMIT vp<[[RED_RESULT:%.+]]> = compute-reduction-result ir<%accum>, ir<%add>
 ; CHECK-NEXT: Successor(s): ir-bb<exit>
 ; CHECK-EMPTY:
 ; CHECK-NEXT: ir-bb<exit>:
-; CHECK-NEXT:   IR   %add.lcssa = phi i32 [ %add, %loop ] (extra operand: vp<%2> from middle.block)
+; CHECK-NEXT:   IR   %add.lcssa = phi i32 [ %add, %loop ] (extra operand: vp<[[RED_RESULT]]> from middle.block)
 ; CHECK-NEXT: No successors
 ; CHECK-EMPTY:
 ; CHECK-NEXT: ir-bb<scalar.ph>:

>From b988436d81533af156fb63b78c371a7a371597e8 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 1 Sep 2025 15:05:37 +0100
Subject: [PATCH 09/10] Address approval review

---
 llvm/include/llvm/Analysis/TargetTransformInfo.h  |  4 ++--
 llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp | 10 ++++------
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index b84f9d7775f4e..af78e0c1e4799 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1649,8 +1649,8 @@ class TargetTransformInfo {
   /// Calculate the cost of an extended reduction pattern, similar to
   /// getArithmeticReductionCost of an Add/Sub reduction with multiply and
   /// optional extensions. This is the cost of as:
-  /// ResTy vecreduce.add/sub(mul (A, B)).
-  /// ResTy vecreduce.add/sub(mul(ext(Ty A), ext(Ty B)).
+  /// * ResTy vecreduce.add/sub(mul (A, B)) or,
+  /// * ResTy vecreduce.add/sub(mul(ext(Ty A), ext(Ty B))).
   LLVM_ABI InstructionCost getMulAccReductionCost(
       bool IsUnsigned, unsigned RedOpcode, Type *ResTy, VectorType *Ty,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const;
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 2b377e4010b99..c31f67891b38d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -3159,8 +3159,7 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
   // Clamp the range if using multiply-accumulate-reduction is profitable.
   auto IsMulAccValidAndClampRange =
       [&](bool IsZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
-          VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt,
-          unsigned Opcode) -> bool {
+          VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt) -> bool {
     return LoopVectorizationPlanner::getDecisionAndClampRange(
         [&](ElementCount VF) {
           TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
@@ -3202,12 +3201,11 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
         match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
         IsMulAccValidAndClampRange(RecipeA->getOpcode() ==
                                        Instruction::CastOps::ZExt,
-                                   MulR, RecipeA, RecipeB, nullptr, Opcode)) {
+                                   MulR, RecipeA, RecipeB, nullptr)) {
       return new VPExpressionRecipe(RecipeA, RecipeB, MulR, Red);
     }
     // Match reduce.add(mul).
-    if (IsMulAccValidAndClampRange(true, MulR, nullptr, nullptr, nullptr,
-                                   Opcode))
+    if (IsMulAccValidAndClampRange(true, MulR, nullptr, nullptr, nullptr))
       return new VPExpressionRecipe(MulR, Red);
   }
   // Match reduce.add(ext(mul(ext(A), ext(B)))).
@@ -3225,7 +3223,7 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
         Ext0->getOpcode() == Ext1->getOpcode() &&
         IsMulAccValidAndClampRange(Ext0->getOpcode() ==
                                        Instruction::CastOps::ZExt,
-                                   Mul, Ext0, Ext1, Ext, Opcode)) {
+                                   Mul, Ext0, Ext1, Ext)) {
       auto *NewExt0 = new VPWidenCastRecipe(
           Ext0->getOpcode(), Ext0->getOperand(0), Ext->getResultType(), *Ext0,
           Ext0->getDebugLoc());

>From 0c55abcd853562e9c6535cb7a5ea530ef5470de9 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 1 Sep 2025 16:25:37 +0100
Subject: [PATCH 10/10] Remove NFCs

---
 llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index c31f67891b38d..362480a923b6f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -3158,7 +3158,7 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
 
   // Clamp the range if using multiply-accumulate-reduction is profitable.
   auto IsMulAccValidAndClampRange =
-      [&](bool IsZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
+      [&](bool isZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
           VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt) -> bool {
     return LoopVectorizationPlanner::getDecisionAndClampRange(
         [&](ElementCount VF) {
@@ -3167,7 +3167,7 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
               Ext0 ? Ctx.Types.inferScalarType(Ext0->getOperand(0)) : RedTy;
           auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF));
           InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
-              IsZExt, Opcode, RedTy, SrcVecTy, CostKind);
+              isZExt, Opcode, RedTy, SrcVecTy, CostKind);
           InstructionCost MulCost = Mul->computeCost(VF, Ctx);
           InstructionCost RedCost = Red->computeCost(VF, Ctx);
           InstructionCost ExtCost = 0;
@@ -3192,7 +3192,7 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
         dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
     auto *RecipeB =
         dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
-    auto *MulR = cast<VPWidenRecipe>(VecOp->getDefiningRecipe());
+    auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe());
 
     // Match reduce.add(mul(ext, ext)).
     if (RecipeA && RecipeB &&
@@ -3201,12 +3201,12 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
         match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
         IsMulAccValidAndClampRange(RecipeA->getOpcode() ==
                                        Instruction::CastOps::ZExt,
-                                   MulR, RecipeA, RecipeB, nullptr)) {
-      return new VPExpressionRecipe(RecipeA, RecipeB, MulR, Red);
+                                   Mul, RecipeA, RecipeB, nullptr)) {
+      return new VPExpressionRecipe(RecipeA, RecipeB, Mul, Red);
     }
     // Match reduce.add(mul).
-    if (IsMulAccValidAndClampRange(true, MulR, nullptr, nullptr, nullptr))
-      return new VPExpressionRecipe(MulR, Red);
+    if (IsMulAccValidAndClampRange(true, Mul, nullptr, nullptr, nullptr))
+      return new VPExpressionRecipe(Mul, Red);
   }
   // Match reduce.add(ext(mul(ext(A), ext(B)))).
   // All extend recipes must have same opcode or A == B



More information about the llvm-commits mailing list