[llvm-branch-commits] [llvm] [LoopVectorizer] Bundle partial reductions with different extensions (PR #136997)

Sam Tebbs via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue May 20 01:21:44 PDT 2025


https://github.com/SamTebbs33 updated https://github.com/llvm/llvm-project/pull/136997

From 10c4727074a7f5b4502ad08dc655be8fa5ffa3d2 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Wed, 23 Apr 2025 13:16:38 +0100
Subject: [PATCH 1/5] [LoopVectorizer] Bundle partial reductions with different
 extensions

This PR adds support for extensions of different signedness to
VPMulAccumulateReductionRecipe, so that partial reductions whose mul
operands are extended with different signedness can be bundled into
that recipe.
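
As a rough illustration, the scalar loop below is a sketch modelled on
the dotp_z_s test in partial-reduce-dot-product-mixed.ll (the value
names and the 1024 trip count are taken from that test, not introduced
by this patch): one mul operand is zero-extended and the other is
sign-extended before being accumulated.

  define i32 @dotp_z_s(ptr %a, ptr %b) {
  entry:
    br label %for.body

  for.body:
    %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
    %accum = phi i32 [ 0, %entry ], [ %add, %for.body ]
    %gep.a = getelementptr i8, ptr %a, i64 %iv
    %load.a = load i8, ptr %gep.a, align 1
    %ext.a = zext i8 %load.a to i32              ; zero-extended operand
    %gep.b = getelementptr i8, ptr %b, i64 %iv
    %load.b = load i8, ptr %gep.b, align 1
    %ext.b = sext i8 %load.b to i32              ; sign-extended operand
    %mul = mul i32 %ext.b, %ext.a                ; mul of a sext and a zext
    %add = add i32 %mul, %accum                  ; reduction accumulator
    %iv.next = add i64 %iv, 1
    %exitcond.not = icmp eq i64 %iv.next, 1024
    br i1 %exitcond.not, label %exit, label %for.body

  exit:
    ret i32 %add
  }

Bundling previously bailed out when the two extend opcodes differed
(the check removed from tryToCreateAbstractPartialReductionRecipe
below); the extend opcode and non-neg flag are now tracked per operand,
so the MULACC-REDUCE recipe prints the signedness of each extend, as in
the updated vplan-printing.ll checks below.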
---
 llvm/lib/Transforms/Vectorize/VPlan.h         | 42 +++++++++-----
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 27 ++++++---
 .../Transforms/Vectorize/VPlanTransforms.cpp  | 25 ++++-----
 .../partial-reduce-dot-product-mixed.ll       | 56 +++++++++----------
 .../LoopVectorize/AArch64/vplan-printing.ll   | 29 +++++-----
 5 files changed, 99 insertions(+), 80 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 20d272e69e6e7..e11f608d068da 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2493,11 +2493,13 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
 /// recipe is abstract and needs to be lowered to concrete recipes before
 /// codegen. The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
 class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
-  /// Opcode of the extend recipe.
-  Instruction::CastOps ExtOp;
+  /// Opcodes of the extend recipes.
+  Instruction::CastOps ExtOp0;
+  Instruction::CastOps ExtOp1;
 
-  /// Non-neg flag of the extend recipe.
-  bool IsNonNeg = false;
+  /// Non-neg flags of the extend recipe.
+  bool IsNonNeg0 = false;
+  bool IsNonNeg1 = false;
 
   Type *ResultTy;
 
@@ -2512,7 +2514,8 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
             MulAcc->getCondOp(), MulAcc->isOrdered(),
             WrapFlagsTy(MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap()),
             MulAcc->getDebugLoc()),
-        ExtOp(MulAcc->getExtOpcode()), IsNonNeg(MulAcc->isNonNeg()),
+        ExtOp0(MulAcc->getExt0Opcode()), ExtOp1(MulAcc->getExt1Opcode()),
+        IsNonNeg0(MulAcc->isNonNeg0()), IsNonNeg1(MulAcc->isNonNeg1()),
         ResultTy(MulAcc->getResultType()),
         IsPartialReduction(MulAcc->isPartialReduction()) {}
 
@@ -2526,7 +2529,8 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
             R->getCondOp(), R->isOrdered(),
             WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
             R->getDebugLoc()),
-        ExtOp(Ext0->getOpcode()), IsNonNeg(Ext0->isNonNeg()),
+        ExtOp0(Ext0->getOpcode()), ExtOp1(Ext1->getOpcode()),
+        IsNonNeg0(Ext0->isNonNeg()), IsNonNeg1(Ext1->isNonNeg()),
         ResultTy(ResultTy),
         IsPartialReduction(isa<VPPartialReductionRecipe>(R)) {
     assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
@@ -2542,7 +2546,8 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
             R->getCondOp(), R->isOrdered(),
             WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
             R->getDebugLoc()),
-        ExtOp(Instruction::CastOps::CastOpsEnd) {
+        ExtOp0(Instruction::CastOps::CastOpsEnd),
+        ExtOp1(Instruction::CastOps::CastOpsEnd) {
     assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
                Instruction::Add &&
            "The reduction instruction in MulAccumulateReductionRecipe must be "
@@ -2586,19 +2591,26 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
   VPValue *getVecOp1() const { return getOperand(2); }
 
   /// Return if this MulAcc recipe contains extend instructions.
-  bool isExtended() const { return ExtOp != Instruction::CastOps::CastOpsEnd; }
+  bool isExtended() const { return ExtOp0 != Instruction::CastOps::CastOpsEnd; }
 
   /// Return if the operands of mul instruction come from same extend.
-  bool isSameExtend() const { return getVecOp0() == getVecOp1(); }
+  bool isSameExtendVal() const { return getVecOp0() == getVecOp1(); }
 
-  /// Return the opcode of the underlying extend.
-  Instruction::CastOps getExtOpcode() const { return ExtOp; }
+  /// Return the opcode of the underlying extends.
+  Instruction::CastOps getExt0Opcode() const { return ExtOp0; }
+  Instruction::CastOps getExt1Opcode() const { return ExtOp1; }
+
+  /// Return if the first extend's opcode is ZExt.
+  bool isZExt0() const { return ExtOp0 == Instruction::CastOps::ZExt; }
+
+  /// Return if the second extend's opcode is ZExt.
+  bool isZExt1() const { return ExtOp1 == Instruction::CastOps::ZExt; }
 
-  /// Return if the extend opcode is ZExt.
-  bool isZExt() const { return ExtOp == Instruction::CastOps::ZExt; }
+  /// Return the non negative flag of the first ext recipe.
+  bool isNonNeg0() const { return IsNonNeg0; }
 
-  /// Return the non negative flag of the ext recipe.
-  bool isNonNeg() const { return IsNonNeg; }
+  /// Return the non negative flag of the second ext recipe.
+  bool isNonNeg1() const { return IsNonNeg1; }
 
   /// Return if the underlying reduction recipe is a partial reduction.
   bool isPartialReduction() const { return IsPartialReduction; }
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index bdc1d49ec88d9..53698fe15d4f8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2438,14 +2438,14 @@ VPMulAccumulateReductionRecipe::computeCost(ElementCount VF,
     return Ctx.TTI.getPartialReductionCost(
         Instruction::Add, Ctx.Types.inferScalarType(getVecOp0()),
         Ctx.Types.inferScalarType(getVecOp1()), getResultType(), VF,
-        TTI::getPartialReductionExtendKind(getExtOpcode()),
-        TTI::getPartialReductionExtendKind(getExtOpcode()), Instruction::Mul);
+        TTI::getPartialReductionExtendKind(getExt0Opcode()),
+        TTI::getPartialReductionExtendKind(getExt1Opcode()), Instruction::Mul);
   }
 
   Type *RedTy = Ctx.Types.inferScalarType(this);
   auto *SrcVecTy =
       cast<VectorType>(toVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
-  return Ctx.TTI.getMulAccReductionCost(isZExt(), RedTy, SrcVecTy,
+  return Ctx.TTI.getMulAccReductionCost(isZExt0(), RedTy, SrcVecTy,
                                         Ctx.CostKind);
 }
 
@@ -2530,13 +2530,24 @@ void VPMulAccumulateReductionRecipe::print(raw_ostream &O, const Twine &Indent,
   if (isExtended())
     O << "(";
   getVecOp0()->printAsOperand(O, SlotTracker);
-  if (isExtended())
-    O << " extended to " << *getResultType() << "), (";
-  else
+  if (isExtended()) {
+    O << " ";
+    if (isZExt0())
+      O << "zero-";
+    else
+      O << "sign-";
+    O << "extended to " << *getResultType() << "), (";
+  } else
     O << ", ";
   getVecOp1()->printAsOperand(O, SlotTracker);
-  if (isExtended())
-    O << " extended to " << *getResultType() << ")";
+  if (isExtended()) {
+    O << " ";
+    if (isZExt1())
+      O << "zero-";
+    else
+      O << "sign-";
+    O << "extended to " << *getResultType() << ")";
+  }
   if (isConditional()) {
     O << ", ";
     getCondOp()->printAsOperand(O, SlotTracker);
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 7a8cbd908c795..f305e09396c1c 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -2121,12 +2121,12 @@ expandVPMulAccumulateReduction(VPMulAccumulateReductionRecipe *MulAcc) {
   VPValue *Op0, *Op1;
   if (MulAcc->isExtended()) {
     Type *RedTy = MulAcc->getResultType();
-    if (MulAcc->isZExt())
-      Op0 = new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
-                                  RedTy, MulAcc->isNonNeg(),
+    if (MulAcc->isZExt0())
+      Op0 = new VPWidenCastRecipe(MulAcc->getExt0Opcode(), MulAcc->getVecOp0(),
+                                  RedTy, MulAcc->isNonNeg0(),
                                   MulAcc->getDebugLoc());
     else
-      Op0 = new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
+      Op0 = new VPWidenCastRecipe(MulAcc->getExt0Opcode(), MulAcc->getVecOp0(),
                                   RedTy, MulAcc->getDebugLoc());
     Op0->getDefiningRecipe()->insertBefore(MulAcc);
     // Prevent reduce.add(mul(ext(A), ext(A))) generate duplicate
@@ -2134,13 +2134,14 @@ expandVPMulAccumulateReduction(VPMulAccumulateReductionRecipe *MulAcc) {
     if (MulAcc->getVecOp0() == MulAcc->getVecOp1()) {
       Op1 = Op0;
     } else {
-      if (MulAcc->isZExt())
-        Op1 = new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
-                                    RedTy, MulAcc->isNonNeg(),
-                                    MulAcc->getDebugLoc());
+      if (MulAcc->isZExt1())
+        Op1 = new VPWidenCastRecipe(MulAcc->getExt1Opcode(),
+                                    MulAcc->getVecOp1(), RedTy,
+                                    MulAcc->isNonNeg1(), MulAcc->getDebugLoc());
       else
-        Op1 = new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
-                                    RedTy, MulAcc->getDebugLoc());
+        Op1 =
+            new VPWidenCastRecipe(MulAcc->getExt1Opcode(), MulAcc->getVecOp1(),
+                                  RedTy, MulAcc->getDebugLoc());
       Op1->getDefiningRecipe()->insertBefore(MulAcc);
     }
   } else {
@@ -2451,10 +2452,8 @@ tryToCreateAbstractPartialReductionRecipe(VPPartialReductionRecipe *PRed) {
 
   auto *Ext0 = dyn_cast<VPWidenCastRecipe>(BinOp->getOperand(0));
   auto *Ext1 = dyn_cast<VPWidenCastRecipe>(BinOp->getOperand(1));
-  // TODO: Make work with extends of different signedness
   if (!Ext0 || Ext0->hasMoreThanOneUniqueUser() || !Ext1 ||
-      Ext1->hasMoreThanOneUniqueUser() ||
-      Ext0->getOpcode() != Ext1->getOpcode())
+      Ext1->hasMoreThanOneUniqueUser())
     return;
 
   auto *AbstractR = new VPMulAccumulateReductionRecipe(PRed, BinOp, Ext0, Ext1,
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-mixed.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-mixed.ll
index f581b6f384bc8..6e1dc7230205b 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-mixed.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-mixed.ll
@@ -22,19 +22,19 @@ define i32 @dotp_z_s(ptr %a, ptr %b) #0 {
 ; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr i8, ptr [[TMP1]], i32 16
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP2]], align 1
 ; CHECK-NEXT:    [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP3]], align 1
-; CHECK-NEXT:    [[TMP4:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
-; CHECK-NEXT:    [[TMP5:%.*]] = zext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP0]]
 ; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[TMP6]], i32 0
 ; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr i8, ptr [[TMP6]], i32 16
 ; CHECK-NEXT:    [[WIDE_LOAD3:%.*]] = load <16 x i8>, ptr [[TMP7]], align 1
 ; CHECK-NEXT:    [[WIDE_LOAD4:%.*]] = load <16 x i8>, ptr [[TMP8]], align 1
-; CHECK-NEXT:    [[TMP9:%.*]] = sext <16 x i8> [[WIDE_LOAD3]] to <16 x i32>
-; CHECK-NEXT:    [[TMP10:%.*]] = sext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
-; CHECK-NEXT:    [[TMP11:%.*]] = mul <16 x i32> [[TMP9]], [[TMP4]]
+; CHECK-NEXT:    [[TMP10:%.*]] = sext <16 x i8> [[WIDE_LOAD3]] to <16 x i32>
+; CHECK-NEXT:    [[TMP5:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP12:%.*]] = mul <16 x i32> [[TMP10]], [[TMP5]]
-; CHECK-NEXT:    [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP11]])
-; CHECK-NEXT:    [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP12]])
+; CHECK-NEXT:    [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP12]])
+; CHECK-NEXT:    [[TMP15:%.*]] = sext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
+; CHECK-NEXT:    [[TMP11:%.*]] = zext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
+; CHECK-NEXT:    [[TMP16:%.*]] = mul <16 x i32> [[TMP15]], [[TMP11]]
+; CHECK-NEXT:    [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP16]])
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 32
 ; CHECK-NEXT:    [[TMP13:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
 ; CHECK-NEXT:    br i1 [[TMP13]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
@@ -60,19 +60,19 @@ define i32 @dotp_z_s(ptr %a, ptr %b) #0 {
 ; CHECK-NOI8MM-NEXT:    [[TMP3:%.*]] = getelementptr i8, ptr [[TMP1]], i32 16
 ; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP2]], align 1
 ; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP3]], align 1
-; CHECK-NOI8MM-NEXT:    [[TMP4:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
-; CHECK-NOI8MM-NEXT:    [[TMP5:%.*]] = zext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
 ; CHECK-NOI8MM-NEXT:    [[TMP6:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP0]]
 ; CHECK-NOI8MM-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[TMP6]], i32 0
 ; CHECK-NOI8MM-NEXT:    [[TMP8:%.*]] = getelementptr i8, ptr [[TMP6]], i32 16
 ; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD3:%.*]] = load <16 x i8>, ptr [[TMP7]], align 1
 ; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD4:%.*]] = load <16 x i8>, ptr [[TMP8]], align 1
-; CHECK-NOI8MM-NEXT:    [[TMP9:%.*]] = sext <16 x i8> [[WIDE_LOAD3]] to <16 x i32>
-; CHECK-NOI8MM-NEXT:    [[TMP10:%.*]] = sext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
-; CHECK-NOI8MM-NEXT:    [[TMP11:%.*]] = mul <16 x i32> [[TMP9]], [[TMP4]]
+; CHECK-NOI8MM-NEXT:    [[TMP10:%.*]] = sext <16 x i8> [[WIDE_LOAD3]] to <16 x i32>
+; CHECK-NOI8MM-NEXT:    [[TMP5:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
 ; CHECK-NOI8MM-NEXT:    [[TMP12:%.*]] = mul <16 x i32> [[TMP10]], [[TMP5]]
-; CHECK-NOI8MM-NEXT:    [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP11]])
-; CHECK-NOI8MM-NEXT:    [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP12]])
+; CHECK-NOI8MM-NEXT:    [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP12]])
+; CHECK-NOI8MM-NEXT:    [[TMP15:%.*]] = sext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
+; CHECK-NOI8MM-NEXT:    [[TMP11:%.*]] = zext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
+; CHECK-NOI8MM-NEXT:    [[TMP16:%.*]] = mul <16 x i32> [[TMP15]], [[TMP11]]
+; CHECK-NOI8MM-NEXT:    [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP16]])
 ; CHECK-NOI8MM-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 32
 ; CHECK-NOI8MM-NEXT:    [[TMP13:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
 ; CHECK-NOI8MM-NEXT:    br i1 [[TMP13]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
@@ -121,19 +121,19 @@ define i32 @dotp_s_z(ptr %a, ptr %b) #0 {
 ; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr i8, ptr [[TMP1]], i32 16
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP2]], align 1
 ; CHECK-NEXT:    [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP3]], align 1
-; CHECK-NEXT:    [[TMP4:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
-; CHECK-NEXT:    [[TMP5:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP0]]
 ; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[TMP6]], i32 0
 ; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr i8, ptr [[TMP6]], i32 16
 ; CHECK-NEXT:    [[WIDE_LOAD3:%.*]] = load <16 x i8>, ptr [[TMP7]], align 1
 ; CHECK-NEXT:    [[WIDE_LOAD4:%.*]] = load <16 x i8>, ptr [[TMP8]], align 1
-; CHECK-NEXT:    [[TMP9:%.*]] = zext <16 x i8> [[WIDE_LOAD3]] to <16 x i32>
-; CHECK-NEXT:    [[TMP10:%.*]] = zext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
-; CHECK-NEXT:    [[TMP11:%.*]] = mul <16 x i32> [[TMP9]], [[TMP4]]
+; CHECK-NEXT:    [[TMP10:%.*]] = zext <16 x i8> [[WIDE_LOAD3]] to <16 x i32>
+; CHECK-NEXT:    [[TMP5:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP12:%.*]] = mul <16 x i32> [[TMP10]], [[TMP5]]
-; CHECK-NEXT:    [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP11]])
-; CHECK-NEXT:    [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP12]])
+; CHECK-NEXT:    [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP12]])
+; CHECK-NEXT:    [[TMP15:%.*]] = zext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
+; CHECK-NEXT:    [[TMP11:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
+; CHECK-NEXT:    [[TMP16:%.*]] = mul <16 x i32> [[TMP15]], [[TMP11]]
+; CHECK-NEXT:    [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP16]])
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 32
 ; CHECK-NEXT:    [[TMP13:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
 ; CHECK-NEXT:    br i1 [[TMP13]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
@@ -159,19 +159,19 @@ define i32 @dotp_s_z(ptr %a, ptr %b) #0 {
 ; CHECK-NOI8MM-NEXT:    [[TMP3:%.*]] = getelementptr i8, ptr [[TMP1]], i32 16
 ; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP2]], align 1
 ; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP3]], align 1
-; CHECK-NOI8MM-NEXT:    [[TMP4:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
-; CHECK-NOI8MM-NEXT:    [[TMP5:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
 ; CHECK-NOI8MM-NEXT:    [[TMP6:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP0]]
 ; CHECK-NOI8MM-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[TMP6]], i32 0
 ; CHECK-NOI8MM-NEXT:    [[TMP8:%.*]] = getelementptr i8, ptr [[TMP6]], i32 16
 ; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD3:%.*]] = load <16 x i8>, ptr [[TMP7]], align 1
 ; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD4:%.*]] = load <16 x i8>, ptr [[TMP8]], align 1
-; CHECK-NOI8MM-NEXT:    [[TMP9:%.*]] = zext <16 x i8> [[WIDE_LOAD3]] to <16 x i32>
-; CHECK-NOI8MM-NEXT:    [[TMP10:%.*]] = zext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
-; CHECK-NOI8MM-NEXT:    [[TMP11:%.*]] = mul <16 x i32> [[TMP9]], [[TMP4]]
+; CHECK-NOI8MM-NEXT:    [[TMP10:%.*]] = zext <16 x i8> [[WIDE_LOAD3]] to <16 x i32>
+; CHECK-NOI8MM-NEXT:    [[TMP5:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
 ; CHECK-NOI8MM-NEXT:    [[TMP12:%.*]] = mul <16 x i32> [[TMP10]], [[TMP5]]
-; CHECK-NOI8MM-NEXT:    [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP11]])
-; CHECK-NOI8MM-NEXT:    [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP12]])
+; CHECK-NOI8MM-NEXT:    [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP12]])
+; CHECK-NOI8MM-NEXT:    [[TMP15:%.*]] = zext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
+; CHECK-NOI8MM-NEXT:    [[TMP11:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
+; CHECK-NOI8MM-NEXT:    [[TMP16:%.*]] = mul <16 x i32> [[TMP15]], [[TMP11]]
+; CHECK-NOI8MM-NEXT:    [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP16]])
 ; CHECK-NOI8MM-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 32
 ; CHECK-NOI8MM-NEXT:    [[TMP13:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
 ; CHECK-NOI8MM-NEXT:    br i1 [[TMP13]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll
index 4dc83ed8a95b5..3911e03de2f50 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll
@@ -6,7 +6,7 @@ target triple = "aarch64-none-unknown-elf"
 
 ; Tests for printing VPlans that are enabled under AArch64
 
-define i32 @print_partial_reduction(ptr %a, ptr %b) {
+define i32 @print_partial_reduction_sext_zext(ptr %a, ptr %b) {
 ; CHECK:      VPlan 'Initial VPlan for VF={8,16},UF>=1' {
 ; CHECK-NEXT: Live-in vp<[[VFxUF:%.]]> = VF * UF
 ; CHECK-NEXT: Live-in vp<[[VEC_TC:%.+]]> = vector-trip-count
@@ -21,18 +21,15 @@ define i32 @print_partial_reduction(ptr %a, ptr %b) {
 ; CHECK-NEXT: <x1> vector loop: {
 ; CHECK-NEXT: vector.body:
 ; CHECK-NEXT:   EMIT vp<[[CAN_IV:%.+]]> = CANONICAL-INDUCTION ir<0>, vp<[[CAN_IV_NEXT:%.+]]>
-; CHECK-NEXT:   WIDEN-REDUCTION-PHI ir<[[ACC:%.+]]> = phi ir<0>, ir<[[REDUCE:%.+]]> (VF scaled by 1/4)
+; CHECK-NEXT:   WIDEN-REDUCTION-PHI ir<[[ACC:%.+]]> = phi ir<0>, vp<[[REDUCE:%.+]]> (VF scaled by 1/4)
 ; CHECK-NEXT:   vp<[[STEPS:%.+]]> = SCALAR-STEPS vp<[[CAN_IV]]>, ir<1>
 ; CHECK-NEXT:   CLONE ir<%gep.a> = getelementptr ir<%a>, vp<[[STEPS]]>
 ; CHECK-NEXT:   vp<[[PTR_A:%.+]]> = vector-pointer ir<%gep.a>
 ; CHECK-NEXT:   WIDEN ir<%load.a> = load vp<[[PTR_A]]>
-; CHECK-NEXT:   WIDEN-CAST ir<%ext.a> = sext ir<%load.a> to i32
 ; CHECK-NEXT:   CLONE ir<%gep.b> = getelementptr ir<%b>, vp<[[STEPS]]>
 ; CHECK-NEXT:   vp<[[PTR_B:%.+]]> = vector-pointer ir<%gep.b>
 ; CHECK-NEXT:   WIDEN ir<%load.b> = load vp<[[PTR_B]]>
-; CHECK-NEXT:   WIDEN-CAST ir<%ext.b> = zext ir<%load.b> to i32
-; CHECK-NEXT:   WIDEN ir<%mul> = mul ir<%ext.b>, ir<%ext.a>
-; CHECK-NEXT:   PARTIAL-REDUCE ir<[[REDUCE]]> = add ir<[[ACC]]>, ir<%mul>
+; CHECK-NEXT:   MULACC-REDUCE vp<[[REDUCE]]> = ir<%accum> + partial.reduce.add (mul (ir<%load.b> zero-extended to i32), (ir<%load.a> sign-extended to i32))
 ; CHECK-NEXT:   EMIT vp<[[CAN_IV_NEXT]]> = add nuw vp<[[CAN_IV]]>, vp<[[VFxUF]]>
 ; CHECK-NEXT:   EMIT branch-on-count vp<[[CAN_IV_NEXT]]>, vp<[[VEC_TC]]>
 ; CHECK-NEXT: No successors
@@ -40,7 +37,7 @@ define i32 @print_partial_reduction(ptr %a, ptr %b) {
 ; CHECK-NEXT: Successor(s): middle.block
 ; CHECK-EMPTY:
 ; CHECK-NEXT: middle.block:
-; CHECK-NEXT:   EMIT vp<[[RED_RESULT:%.+]]> = compute-reduction-result ir<[[ACC]]>, ir<[[REDUCE]]>
+; CHECK-NEXT:   EMIT vp<[[RED_RESULT:%.+]]> = compute-reduction-result ir<[[ACC]]>, vp<[[REDUCE]]>
 ; CHECK-NEXT:   EMIT vp<[[EXTRACT:%.+]]> = extract-from-end vp<[[RED_RESULT]]>, ir<1>
 ; CHECK-NEXT:   EMIT vp<[[CMP:%.+]]> = icmp eq ir<1024>, vp<%1>
 ; CHECK-NEXT:   EMIT branch-on-cond vp<[[CMP]]>
@@ -87,18 +84,18 @@ define i32 @print_partial_reduction(ptr %a, ptr %b) {
 ; CHECK-NEXT: <x1> vector loop: {
 ; CHECK-NEXT:   vector.body:
 ; CHECK-NEXT:     EMIT vp<[[EP_IV:%.+]]> = phi ir<0>, vp<%index.next>
-; CHECK-NEXT:     WIDEN-REDUCTION-PHI ir<%accum> = phi ir<0>, ir<%add> (VF scaled by 1/4)
+; CHECK-NEXT:     WIDEN-REDUCTION-PHI ir<%accum> = phi ir<0>, vp<[[REDUCE:%.+]]> (VF scaled by 1/4)
 ; CHECK-NEXT:     vp<[[STEPS:%.+]]> = SCALAR-STEPS vp<[[EP_IV]]>, ir<1>
 ; CHECK-NEXT:     CLONE ir<%gep.a> = getelementptr ir<%a>, vp<[[STEPS]]>
 ; CHECK-NEXT:     vp<[[PTR_A:%.+]]> = vector-pointer ir<%gep.a>
 ; CHECK-NEXT:     WIDEN ir<%load.a> = load vp<[[PTR_A]]>
-; CHECK-NEXT:     WIDEN-CAST ir<%ext.a> = sext ir<%load.a> to i32
 ; CHECK-NEXT:     CLONE ir<%gep.b> = getelementptr ir<%b>, vp<[[STEPS]]>
 ; CHECK-NEXT:     vp<[[PTR_B:%.+]]> = vector-pointer ir<%gep.b>
 ; CHECK-NEXT:     WIDEN ir<%load.b> = load vp<[[PTR_B]]>
-; CHECK-NEXT:     WIDEN-CAST ir<%ext.b> = zext ir<%load.b> to i32
-; CHECK-NEXT:     WIDEN ir<%mul> = mul ir<%ext.b>, ir<%ext.a>
-; CHECK-NEXT:     PARTIAL-REDUCE ir<%add> = add ir<%accum>, ir<%mul>
+; CHECK-NEXT:     WIDEN-CAST vp<[[EXTB:%.+]]> = zext ir<%load.b> to i32
+; CHECK-NEXT:     WIDEN-CAST vp<[[EXTA:%.+]]> = sext ir<%load.a> to i32
+; CHECK-NEXT:     WIDEN vp<[[MUL:%.+]]> = mul vp<[[EXTB]]>, vp<[[EXTA]]>
+; CHECK-NEXT:     PARTIAL-REDUCE vp<[[REDUCE]]> = add ir<%accum>, vp<[[MUL]]>
 ; CHECK-NEXT:     EMIT vp<[[EP_IV_NEXT:%.+]]> = add nuw vp<[[EP_IV]]>, ir<16>
 ; CHECK-NEXT:     EMIT branch-on-count vp<[[EP_IV_NEXT]]>, ir<1024>
 ; CHECK-NEXT:   No successors
@@ -106,7 +103,7 @@ define i32 @print_partial_reduction(ptr %a, ptr %b) {
 ; CHECK-NEXT: Successor(s): middle.block
 ; CHECK-EMPTY:
 ; CHECK-NEXT: middle.block:
-; CHECK-NEXT:   EMIT vp<[[RED_RESULT:%.+]]> = compute-reduction-result ir<%accum>, ir<%add>
+; CHECK-NEXT:   EMIT vp<[[RED_RESULT:%.+]]> = compute-reduction-result ir<%accum>, vp<[[REDUCE]]>
 ; CHECK-NEXT:   EMIT vp<[[EXTRACT:%.+]]> = extract-from-end vp<[[RED_RESULT]]>, ir<1>
 ; CHECK-NEXT:   EMIT vp<[[CMP:%.+]]> = icmp eq ir<1024>, ir<1024>
 ; CHECK-NEXT:   EMIT branch-on-cond vp<[[CMP]]>
@@ -180,7 +177,7 @@ define i32 @print_bundled_partial_reduction(ptr %a, ptr %b) {
 ; CHECK-NEXT:   CLONE ir<%gep.b> = getelementptr ir<%b>, vp<[[STEPS]]>
 ; CHECK-NEXT:   vp<[[PTR_B:%.+]]> = vector-pointer ir<%gep.b>
 ; CHECK-NEXT:   WIDEN ir<%load.b> = load vp<[[PTR_B]]>
-; CHECK-NEXT:   MULACC-REDUCE vp<[[REDUCE:%.+]]> = ir<%accum> + partial.reduce.add (mul (ir<%load.b> extended to i32), (ir<%load.a> extended to i32))
+; CHECK-NEXT:   MULACC-REDUCE vp<[[REDUCE:%.+]]> = ir<%accum> + partial.reduce.add (mul (ir<%load.b> zero-extended to i32), (ir<%load.a> zero-extended to i32))
 ; CHECK-NEXT:   EMIT vp<[[CAN_IV_NEXT]]> = add nuw vp<[[CAN_IV]]>, vp<[[VFxUF]]>
 ; CHECK-NEXT:   EMIT branch-on-count vp<[[CAN_IV_NEXT]]>, vp<[[VEC_TC]]>
 ; CHECK-NEXT: No successors
@@ -235,7 +232,7 @@ define i32 @print_bundled_partial_reduction(ptr %a, ptr %b) {
 ; CHECK-NEXT: <x1> vector loop: {
 ; CHECK-NEXT:   vector.body:
 ; CHECK-NEXT:     EMIT vp<[[EP_IV:%.+]]> = phi ir<0>, vp<%index.next>
-; CHECK-NEXT:     WIDEN-REDUCTION-PHI ir<%accum> = phi ir<0>, vp<%7> (VF scaled by 1/4)
+; CHECK-NEXT:     WIDEN-REDUCTION-PHI ir<%accum> = phi ir<0>, vp<[[REDUCE:%.+]]> (VF scaled by 1/4)
 ; CHECK-NEXT:     vp<[[STEPS:%.+]]> = SCALAR-STEPS vp<[[EP_IV]]>, ir<1>
 ; CHECK-NEXT:     CLONE ir<%gep.a> = getelementptr ir<%a>, vp<[[STEPS]]>
 ; CHECK-NEXT:     vp<[[PTR_A:%.+]]> = vector-pointer ir<%gep.a>
@@ -246,7 +243,7 @@ define i32 @print_bundled_partial_reduction(ptr %a, ptr %b) {
 ; CHECK-NEXT:     WIDEN-CAST vp<%4> = zext ir<%load.b> to i32
 ; CHECK-NEXT:     WIDEN-CAST vp<%5> = zext ir<%load.a> to i32
 ; CHECK-NEXT:     WIDEN vp<%6> = mul vp<%4>, vp<%5>
-; CHECK-NEXT:     PARTIAL-REDUCE vp<[[REDUCE:%.+]]> = add ir<%accum>, vp<%6>
+; CHECK-NEXT:     PARTIAL-REDUCE vp<[[REDUCE]]> = add ir<%accum>, vp<%6>
 ; CHECK-NEXT:     EMIT vp<[[EP_IV_NEXT:%.+]]> = add nuw vp<[[EP_IV]]>, ir<16>
 ; CHECK-NEXT:     EMIT branch-on-count vp<[[EP_IV_NEXT]]>, ir<1024>
 ; CHECK-NEXT:   No successors

From 40e44cf2ff40ba651b64e84d9fefd8c496e5ffa9 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Mon, 28 Apr 2025 11:44:20 +0100
Subject: [PATCH 2/5] Create VecOperandInfo

---
 llvm/lib/Transforms/Vectorize/VPlan.h         | 53 ++++++++-----------
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 17 +++---
 .../Transforms/Vectorize/VPlanTransforms.cpp  | 27 +++++-----
 3 files changed, 47 insertions(+), 50 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index e11f608d068da..c3839866d4c27 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2493,13 +2493,6 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
 /// recipe is abstract and needs to be lowered to concrete recipes before
 /// codegen. The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
 class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
-  /// Opcodes of the extend recipes.
-  Instruction::CastOps ExtOp0;
-  Instruction::CastOps ExtOp1;
-
-  /// Non-neg flags of the extend recipe.
-  bool IsNonNeg0 = false;
-  bool IsNonNeg1 = false;
 
   Type *ResultTy;
 
@@ -2514,10 +2507,11 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
             MulAcc->getCondOp(), MulAcc->isOrdered(),
             WrapFlagsTy(MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap()),
             MulAcc->getDebugLoc()),
-        ExtOp0(MulAcc->getExt0Opcode()), ExtOp1(MulAcc->getExt1Opcode()),
-        IsNonNeg0(MulAcc->isNonNeg0()), IsNonNeg1(MulAcc->isNonNeg1()),
         ResultTy(MulAcc->getResultType()),
-        IsPartialReduction(MulAcc->isPartialReduction()) {}
+        IsPartialReduction(MulAcc->isPartialReduction()) {
+    VecOpInfo[0] = MulAcc->getVecOp0Info();
+    VecOpInfo[1] = MulAcc->getVecOp1Info();
+  }
 
 public:
   VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul,
@@ -2529,14 +2523,14 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
             R->getCondOp(), R->isOrdered(),
             WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
             R->getDebugLoc()),
-        ExtOp0(Ext0->getOpcode()), ExtOp1(Ext1->getOpcode()),
-        IsNonNeg0(Ext0->isNonNeg()), IsNonNeg1(Ext1->isNonNeg()),
         ResultTy(ResultTy),
         IsPartialReduction(isa<VPPartialReductionRecipe>(R)) {
     assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
                Instruction::Add &&
            "The reduction instruction in MulAccumulateteReductionRecipe must "
            "be Add");
+    VecOpInfo[0] = {Ext0->getOpcode(), Ext0->isNonNeg()};
+    VecOpInfo[1] = {Ext1->getOpcode(), Ext1->isNonNeg()};
   }
 
   VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul)
@@ -2545,15 +2539,20 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
             {R->getChainOp(), Mul->getOperand(0), Mul->getOperand(1)},
             R->getCondOp(), R->isOrdered(),
             WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
-            R->getDebugLoc()),
-        ExtOp0(Instruction::CastOps::CastOpsEnd),
-        ExtOp1(Instruction::CastOps::CastOpsEnd) {
+            R->getDebugLoc()) {
     assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
                Instruction::Add &&
            "The reduction instruction in MulAccumulateReductionRecipe must be "
            "Add");
   }
 
+  struct VecOperandInfo {
+    /// The operand's extend opcode.
+    Instruction::CastOps ExtOp{Instruction::CastOps::CastOpsEnd};
+    /// Non-neg portion of the operand's flags.
+    bool IsNonNeg = false;
+  };
+
   ~VPMulAccumulateReductionRecipe() override = default;
 
   VPMulAccumulateReductionRecipe *clone() override {
@@ -2591,29 +2590,21 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
   VPValue *getVecOp1() const { return getOperand(2); }
 
   /// Return if this MulAcc recipe contains extend instructions.
-  bool isExtended() const { return ExtOp0 != Instruction::CastOps::CastOpsEnd; }
+  bool isExtended() const {
+    return getVecOp0Info().ExtOp != Instruction::CastOps::CastOpsEnd;
+  }
 
   /// Return if the operands of mul instruction come from same extend.
   bool isSameExtendVal() const { return getVecOp0() == getVecOp1(); }
 
-  /// Return the opcode of the underlying extends.
-  Instruction::CastOps getExt0Opcode() const { return ExtOp0; }
-  Instruction::CastOps getExt1Opcode() const { return ExtOp1; }
-
-  /// Return if the first extend's opcode is ZExt.
-  bool isZExt0() const { return ExtOp0 == Instruction::CastOps::ZExt; }
-
-  /// Return if the second extend's opcode is ZExt.
-  bool isZExt1() const { return ExtOp1 == Instruction::CastOps::ZExt; }
-
-  /// Return the non negative flag of the first ext recipe.
-  bool isNonNeg0() const { return IsNonNeg0; }
-
-  /// Return the non negative flag of the second ext recipe.
-  bool isNonNeg1() const { return IsNonNeg1; }
+  VecOperandInfo getVecOp0Info() const { return VecOpInfo[0]; }
+  VecOperandInfo getVecOp1Info() const { return VecOpInfo[1]; }
 
   /// Return if the underlying reduction recipe is a partial reduction.
   bool isPartialReduction() const { return IsPartialReduction; }
+
+protected:
+  VecOperandInfo VecOpInfo[2];
 };
 
 /// VPReplicateRecipe replicates a given instruction producing multiple scalar
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 53698fe15d4f8..d805f6c254b5b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2434,19 +2434,22 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
 InstructionCost
 VPMulAccumulateReductionRecipe::computeCost(ElementCount VF,
                                             VPCostContext &Ctx) const {
+  VecOperandInfo Op0Info = getVecOp0Info();
+  VecOperandInfo Op1Info = getVecOp1Info();
   if (isPartialReduction()) {
     return Ctx.TTI.getPartialReductionCost(
         Instruction::Add, Ctx.Types.inferScalarType(getVecOp0()),
         Ctx.Types.inferScalarType(getVecOp1()), getResultType(), VF,
-        TTI::getPartialReductionExtendKind(getExt0Opcode()),
-        TTI::getPartialReductionExtendKind(getExt1Opcode()), Instruction::Mul);
+        TTI::getPartialReductionExtendKind(Op0Info.ExtOp),
+        TTI::getPartialReductionExtendKind(Op1Info.ExtOp), Instruction::Mul);
   }
 
   Type *RedTy = Ctx.Types.inferScalarType(this);
   auto *SrcVecTy =
       cast<VectorType>(toVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
-  return Ctx.TTI.getMulAccReductionCost(isZExt0(), RedTy, SrcVecTy,
-                                        Ctx.CostKind);
+  return Ctx.TTI.getMulAccReductionCost(Op0Info.ExtOp ==
+                                            Instruction::CastOps::ZExt,
+                                        RedTy, SrcVecTy, Ctx.CostKind);
 }
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
@@ -2514,6 +2517,8 @@ void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
 
 void VPMulAccumulateReductionRecipe::print(raw_ostream &O, const Twine &Indent,
                                            VPSlotTracker &SlotTracker) const {
+  VecOperandInfo Op0Info = getVecOp0Info();
+  VecOperandInfo Op1Info = getVecOp1Info();
   O << Indent << "MULACC-REDUCE ";
   printAsOperand(O, SlotTracker);
   O << " = ";
@@ -2532,7 +2537,7 @@ void VPMulAccumulateReductionRecipe::print(raw_ostream &O, const Twine &Indent,
   getVecOp0()->printAsOperand(O, SlotTracker);
   if (isExtended()) {
     O << " ";
-    if (isZExt0())
+    if (Op0Info.ExtOp == Instruction::CastOps::ZExt)
       O << "zero-";
     else
       O << "sign-";
@@ -2542,7 +2547,7 @@ void VPMulAccumulateReductionRecipe::print(raw_ostream &O, const Twine &Indent,
   getVecOp1()->printAsOperand(O, SlotTracker);
   if (isExtended()) {
     O << " ";
-    if (isZExt1())
+    if (Op1Info.ExtOp == Instruction::CastOps::ZExt)
       O << "zero-";
     else
       O << "sign-";
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index f305e09396c1c..479852f119b57 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -2120,28 +2120,29 @@ expandVPMulAccumulateReduction(VPMulAccumulateReductionRecipe *MulAcc) {
   // reduce.add(ext(mul(ext, ext))) to reduce.add(mul(ext, ext)).
   VPValue *Op0, *Op1;
   if (MulAcc->isExtended()) {
+    VPMulAccumulateReductionRecipe::VecOperandInfo Op0Info =
+        MulAcc->getVecOp0Info();
+    VPMulAccumulateReductionRecipe::VecOperandInfo Op1Info =
+        MulAcc->getVecOp1Info();
     Type *RedTy = MulAcc->getResultType();
-    if (MulAcc->isZExt0())
-      Op0 = new VPWidenCastRecipe(MulAcc->getExt0Opcode(), MulAcc->getVecOp0(),
-                                  RedTy, MulAcc->isNonNeg0(),
-                                  MulAcc->getDebugLoc());
+    if (Op0Info.ExtOp == Instruction::CastOps::ZExt)
+      Op0 = new VPWidenCastRecipe(Op0Info.ExtOp, MulAcc->getVecOp0(), RedTy,
+                                  Op0Info.IsNonNeg, MulAcc->getDebugLoc());
     else
-      Op0 = new VPWidenCastRecipe(MulAcc->getExt0Opcode(), MulAcc->getVecOp0(),
-                                  RedTy, MulAcc->getDebugLoc());
+      Op0 = new VPWidenCastRecipe(Op0Info.ExtOp, MulAcc->getVecOp0(), RedTy,
+                                  MulAcc->getDebugLoc());
     Op0->getDefiningRecipe()->insertBefore(MulAcc);
     // Prevent reduce.add(mul(ext(A), ext(A))) generate duplicate
     // VPWidenCastRecipe.
     if (MulAcc->getVecOp0() == MulAcc->getVecOp1()) {
       Op1 = Op0;
     } else {
-      if (MulAcc->isZExt1())
-        Op1 = new VPWidenCastRecipe(MulAcc->getExt1Opcode(),
-                                    MulAcc->getVecOp1(), RedTy,
-                                    MulAcc->isNonNeg1(), MulAcc->getDebugLoc());
+      if (Op1Info.ExtOp == Instruction::CastOps::ZExt)
+        Op1 = new VPWidenCastRecipe(Op1Info.ExtOp, MulAcc->getVecOp1(), RedTy,
+                                    Op1Info.IsNonNeg, MulAcc->getDebugLoc());
       else
-        Op1 =
-            new VPWidenCastRecipe(MulAcc->getExt1Opcode(), MulAcc->getVecOp1(),
-                                  RedTy, MulAcc->getDebugLoc());
+        Op1 = new VPWidenCastRecipe(Op1Info.ExtOp, MulAcc->getVecOp1(), RedTy,
+                                    MulAcc->getDebugLoc());
       Op1->getDefiningRecipe()->insertBefore(MulAcc);
     }
   } else {

From d0c03435bec9d79d0b6a14b8dac2e0187cf0c7d7 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Thu, 8 May 2025 11:42:07 +0100
Subject: [PATCH 3/5] Correct printing test

---
 llvm/test/Transforms/LoopVectorize/vplan-printing.ll | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
index 9bcea72633d3a..9cc49b17c9d93 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
@@ -1312,7 +1312,7 @@ define i64 @print_mulacc_extended(ptr nocapture readonly %x, ptr nocapture reado
 ; CHECK-NEXT:     CLONE ir<%arrayidx1> = getelementptr inbounds ir<%y>, vp<%3>
 ; CHECK-NEXT:     vp<%5> = vector-pointer ir<%arrayidx1>
 ; CHECK-NEXT:     WIDEN ir<%load1> = load vp<%5>
-; CHECK-NEXT:     MULACC-REDUCE vp<%6> = ir<%rdx> + reduce.add (mul nsw (ir<%load0> extended to i64), (ir<%load1> extended to i64))
+; CHECK-NEXT:     MULACC-REDUCE vp<%6> = ir<%rdx> + reduce.add (mul nsw (ir<%load0> sign-extended to i64), (ir<%load1> sign-extended to i64))
 ; CHECK-NEXT:     EMIT vp<%index.next> = add nuw vp<%2>, vp<%0>
 ; CHECK-NEXT:     EMIT branch-on-count vp<%index.next>, vp<%1>
 ; CHECK-NEXT:   No successors

From 6e82fc5e92ac7501115c35a419d3596016924962 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Tue, 13 May 2025 14:05:48 +0100
Subject: [PATCH 4/5] Return reference from getVecOpXInfo

---
 llvm/lib/Transforms/Vectorize/VPlan.h | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index c3839866d4c27..6743076812a0e 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2508,10 +2508,8 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
             WrapFlagsTy(MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap()),
             MulAcc->getDebugLoc()),
         ResultTy(MulAcc->getResultType()),
-        IsPartialReduction(MulAcc->isPartialReduction()) {
-    VecOpInfo[0] = MulAcc->getVecOp0Info();
-    VecOpInfo[1] = MulAcc->getVecOp1Info();
-  }
+        IsPartialReduction(MulAcc->isPartialReduction()),
+        VecOpInfo{MulAcc->getVecOp0Info(), MulAcc->getVecOp1Info()} {}
 
 public:
   VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul,
@@ -2524,13 +2522,13 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
             WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
             R->getDebugLoc()),
         ResultTy(ResultTy),
-        IsPartialReduction(isa<VPPartialReductionRecipe>(R)) {
+        IsPartialReduction(isa<VPPartialReductionRecipe>(R)),
+        VecOpInfo{{Ext0->getOpcode(), Ext0->isNonNeg()},
+                  {Ext1->getOpcode(), Ext1->isNonNeg()}} {
     assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
                Instruction::Add &&
            "The reduction instruction in MulAccumulateteReductionRecipe must "
            "be Add");
-    VecOpInfo[0] = {Ext0->getOpcode(), Ext0->isNonNeg()};
-    VecOpInfo[1] = {Ext1->getOpcode(), Ext1->isNonNeg()};
   }
 
   VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul)
@@ -2597,8 +2595,8 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
   /// Return if the operands of mul instruction come from same extend.
   bool isSameExtendVal() const { return getVecOp0() == getVecOp1(); }
 
-  VecOperandInfo getVecOp0Info() const { return VecOpInfo[0]; }
-  VecOperandInfo getVecOp1Info() const { return VecOpInfo[1]; }
+  const VecOperandInfo &getVecOp0Info() const { return VecOpInfo[0]; }
+  const VecOperandInfo &getVecOp1Info() const { return VecOpInfo[1]; }
 
   /// Return if the underlying reduction recipe is a partial reduction.
   bool isPartialReduction() const { return IsPartialReduction; }

From d604652fbef1b0dabd0fb32067d4461fe03a2f11 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 19 May 2025 16:48:02 +0100
Subject: [PATCH 5/5] Also check other op info in isExtended()

---
 llvm/lib/Transforms/Vectorize/VPlan.h | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 6743076812a0e..a2a49b5928587 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2589,7 +2589,8 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
 
   /// Return if this MulAcc recipe contains extend instructions.
   bool isExtended() const {
-    return getVecOp0Info().ExtOp != Instruction::CastOps::CastOpsEnd;
+    return getVecOp0Info().ExtOp != Instruction::CastOps::CastOpsEnd ||
+           getVecOp1Info().ExtOp != Instruction::CastOps::CastOpsEnd;
   }
 
   /// Return if the operands of mul instruction come from same extend.
