[llvm] [LoopVectorizer] Add support for partial reductions (PR #92418)
via llvm-commits
llvm-commits at lists.llvm.org
Thu May 16 09:04:49 PDT 2024
github-actions[bot] wrote:
<!--LLVM CODE FORMAT COMMENT: {clang-format}-->
:warning: The C/C++ code formatter, clang-format, found issues in your code. :warning:
<details>
<summary>
You can test this locally with the following command:
</summary>
``````````bash
git-clang-format --diff 506c84a7198630b7476b02d985c6ed09338f757d 1640de07bf8facb2b5284d88481a96cd026e3613 -- llvm/include/llvm/IR/DerivedTypes.h llvm/include/llvm/IR/Intrinsics.h llvm/lib/IR/Function.cpp llvm/lib/Transforms/Vectorize/LoopVectorize.cpp llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h llvm/lib/Transforms/Vectorize/VPlan.h llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp llvm/lib/Transforms/Vectorize/VPlanAnalysis.h llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp llvm/lib/Transforms/Vectorize/VPlanValue.h
``````````
</details>
<details>
<summary>
View the diff from clang-format here.
</summary>
``````````diff
diff --git a/llvm/include/llvm/IR/DerivedTypes.h b/llvm/include/llvm/IR/DerivedTypes.h
index 866a01c9af..1c045ccfb8 100644
--- a/llvm/include/llvm/IR/DerivedTypes.h
+++ b/llvm/include/llvm/IR/DerivedTypes.h
@@ -512,8 +512,8 @@ public:
EltCnt.divideCoefficientBy(2));
}
- /// This static method returns a VectorType with quarter as many elements as the
- /// input type and the same element type.
+ /// This static method returns a VectorType with quarter as many elements as
+ /// the input type and the same element type.
static VectorType *getQuarterElementsVectorType(VectorType *VTy) {
auto EltCnt = VTy->getElementCount();
assert(EltCnt.isKnownEven() &&
diff --git a/llvm/include/llvm/IR/Intrinsics.h b/llvm/include/llvm/IR/Intrinsics.h
index e03e7e0bf5..a5c53aadd5 100644
--- a/llvm/include/llvm/IR/Intrinsics.h
+++ b/llvm/include/llvm/IR/Intrinsics.h
@@ -161,16 +161,16 @@ namespace Intrinsic {
unsigned getArgumentNumber() const {
assert(Kind == Argument || Kind == ExtendArgument ||
- Kind == TruncArgument || Kind == HalfVecArgument || Kind == QuarterVecArgument ||
- Kind == SameVecWidthArgument || Kind == VecElementArgument ||
- Kind == Subdivide2Argument || Kind == Subdivide4Argument ||
- Kind == VecOfBitcastsToInt);
+ Kind == TruncArgument || Kind == HalfVecArgument ||
+ Kind == QuarterVecArgument || Kind == SameVecWidthArgument ||
+ Kind == VecElementArgument || Kind == Subdivide2Argument ||
+ Kind == Subdivide4Argument || Kind == VecOfBitcastsToInt);
return Argument_Info >> 3;
}
ArgKind getArgumentKind() const {
assert(Kind == Argument || Kind == ExtendArgument ||
- Kind == TruncArgument || Kind == HalfVecArgument || Kind == QuarterVecArgument ||
- Kind == SameVecWidthArgument ||
+ Kind == TruncArgument || Kind == HalfVecArgument ||
+ Kind == QuarterVecArgument || Kind == SameVecWidthArgument ||
Kind == VecElementArgument || Kind == Subdivide2Argument ||
Kind == Subdivide4Argument || Kind == VecOfBitcastsToInt);
return (ArgKind)(Argument_Info & 7);
diff --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp
index e9eebd5e35..59b866d70e 100644
--- a/llvm/lib/IR/Function.cpp
+++ b/llvm/lib/IR/Function.cpp
@@ -1242,8 +1242,8 @@ static void DecodeIITType(unsigned &NextElt, ArrayRef<unsigned char> Infos,
}
case IIT_QUARTER_VEC_ARG: {
unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]);
- OutputTable.push_back(IITDescriptor::get(IITDescriptor::QuarterVecArgument,
- ArgInfo));
+ OutputTable.push_back(
+ IITDescriptor::get(IITDescriptor::QuarterVecArgument, ArgInfo));
return;
}
case IIT_SAME_VEC_WIDTH_ARG: {
@@ -1410,8 +1410,9 @@ static Type *DecodeFixedType(ArrayRef<Intrinsic::IITDescriptor> &Infos,
case IITDescriptor::HalfVecArgument:
return VectorType::getHalfElementsVectorType(cast<VectorType>(
Tys[D.getArgumentNumber()]));
- case IITDescriptor::QuarterVecArgument: {
- return VectorType::getQuarterElementsVectorType(cast<VectorType>(Tys[D.getArgumentNumber()]));
+ case IITDescriptor::QuarterVecArgument: {
+ return VectorType::getQuarterElementsVectorType(
+ cast<VectorType>(Tys[D.getArgumentNumber()]));
}
case IITDescriptor::SameVecWidthArgument: {
Type *EltTy = DecodeFixedType(Infos, Tys, Context);
@@ -1629,11 +1630,11 @@ static bool matchIntrinsicType(
VectorType::getHalfElementsVectorType(
cast<VectorType>(ArgTys[D.getArgumentNumber()])) != Ty;
case IITDescriptor::QuarterVecArgument: {
- if (D.getArgumentNumber() >= ArgTys.size())
+ if (D.getArgumentNumber() >= ArgTys.size())
return IsDeferredCheck || DeferCheck(Ty);
return !isa<VectorType>(ArgTys[D.getArgumentNumber()]) ||
VectorType::getQuarterElementsVectorType(
- cast<VectorType>(ArgTys[D.getArgumentNumber()])) != Ty;
+ cast<VectorType>(ArgTys[D.getArgumentNumber()])) != Ty;
}
case IITDescriptor::SameVecWidthArgument: {
if (D.getArgumentNumber() >= ArgTys.size()) {
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 1f37df061b..fafece8a04 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -2203,7 +2203,8 @@ static bool useActiveLaneMaskForControlFlow(TailFoldingStyle Style) {
Style == TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck;
}
-static void getPartialReductionInstrChain(Instruction *Instr, SmallVector<Value*, 4> &Chain) {
+static void getPartialReductionInstrChain(Instruction *Instr,
+ SmallVector<Value *, 4> &Chain) {
Instruction *Mul = cast<Instruction>(Instr->getOperand(0));
Instruction *Ext0 = cast<ZExtInst>(Mul->getOperand(0));
Instruction *Ext1 = cast<ZExtInst>(Mul->getOperand(1));
@@ -2214,7 +2215,6 @@ static void getPartialReductionInstrChain(Instruction *Instr, SmallVector<Value*
Chain.push_back(Instr->getOperand(1));
}
-
/// @param Instr The root instruction to scan
static bool isInstrPartialReduction(Instruction *Instr) {
Value *ExpectedPhi;
@@ -2223,27 +2223,23 @@ static bool isInstrPartialReduction(Instruction *Instr) {
using namespace llvm::PatternMatch;
auto Pattern = m_Add(
- m_OneUse(m_Mul(
- m_OneUse(m_ZExt(
- m_OneUse(m_Load(
- m_GEP(
- m_Value(A),
- m_Value(InductionA)))))),
- m_OneUse(m_ZExt(
- m_OneUse(m_Load(
- m_GEP(
- m_Value(B),
- m_Value(InductionB))))))
- )), m_Value(ExpectedPhi));
+ m_OneUse(m_Mul(m_OneUse(m_ZExt(m_OneUse(
+ m_Load(m_GEP(m_Value(A), m_Value(InductionA)))))),
+ m_OneUse(m_ZExt(m_OneUse(
+ m_Load(m_GEP(m_Value(B), m_Value(InductionB)))))))),
+ m_Value(ExpectedPhi));
bool Matches = match(Instr, Pattern);
- if(!Matches)
+ if (!Matches)
return false;
- // Check that the two induction variable uses are to the same induction variable
- if(InductionA != InductionB) {
- LLVM_DEBUG(dbgs() << "Loop uses different induction variables for each input variable, cannot create a partial reduction.\n");
+ // Check that the two induction variable uses are to the same induction
+ // variable
+ if (InductionA != InductionB) {
+ LLVM_DEBUG(
+ dbgs() << "Loop uses different induction variables for each input "
+ "variable, cannot create a partial reduction.\n");
return false;
}
@@ -2252,37 +2248,42 @@ static bool isInstrPartialReduction(Instruction *Instr) {
Instruction *Ext1 = cast<ZExtInst>(Mul->getOperand(1));
// Check that the extends extend to i32
- if(!Ext0->getType()->isIntegerTy(32) || !Ext1->getType()->isIntegerTy(32)) {
- LLVM_DEBUG(dbgs() << "Extends don't extend to the correct width, cannot create a partial reduction.\n");
+ if (!Ext0->getType()->isIntegerTy(32) || !Ext1->getType()->isIntegerTy(32)) {
+ LLVM_DEBUG(dbgs() << "Extends don't extend to the correct width, cannot "
+ "create a partial reduction.\n");
return false;
}
// Check that the loads are loading i8
LoadInst *Load0 = cast<LoadInst>(Ext0->getOperand(0));
LoadInst *Load1 = cast<LoadInst>(Ext1->getOperand(0));
- if(!Load0->getType()->isIntegerTy(8) || !Load1->getType()->isIntegerTy(8)) {
- LLVM_DEBUG(dbgs() << "Loads don't load the correct width, cannot create a partial reduction\n");
+ if (!Load0->getType()->isIntegerTy(8) || !Load1->getType()->isIntegerTy(8)) {
+ LLVM_DEBUG(dbgs() << "Loads don't load the correct width, cannot create a "
+ "partial reduction\n");
return false;
}
// Check that the add feeds into ExpectedPhi
PHINode *PhiNode = dyn_cast<PHINode>(ExpectedPhi);
- if(!PhiNode) {
- LLVM_DEBUG(dbgs() << "Expected Phi node was not a phi, cannot create a partial reduction.\n");
+ if (!PhiNode) {
+ LLVM_DEBUG(dbgs() << "Expected Phi node was not a phi, cannot create a "
+ "partial reduction.\n");
return false;
}
// Check that the first phi value is a zero initializer
ConstantInt *ZeroInit = dyn_cast<ConstantInt>(PhiNode->getIncomingValue(0));
- if(!ZeroInit || !ZeroInit->isZero()) {
- LLVM_DEBUG(dbgs() << "First PHI value is not a constant zero, cannot create a partial reduction.\n");
+ if (!ZeroInit || !ZeroInit->isZero()) {
+ LLVM_DEBUG(dbgs() << "First PHI value is not a constant zero, cannot "
+ "create a partial reduction.\n");
return false;
}
// Check that the second phi value is the instruction we're looking at
Instruction *MaybeAdd = dyn_cast<Instruction>(PhiNode->getIncomingValue(1));
- if(!MaybeAdd || MaybeAdd != Instr) {
- LLVM_DEBUG(dbgs() << "Second PHI value is not the root add, cannot create a partial reduction.\n");
+ if (!MaybeAdd || MaybeAdd != Instr) {
+ LLVM_DEBUG(dbgs() << "Second PHI value is not the root add, cannot create "
+ "a partial reduction.\n");
return false;
}
@@ -5172,9 +5173,11 @@ bool LoopVectorizationPlanner::isCandidateForEpilogueVectorization(
// Prevent epilogue vectorization if a partial reduction is involved
// TODO Is there a cleaner way to check this?
- if(any_of(Legal->getReductionVars(), [&](const std::pair<PHINode *, RecurrenceDescriptor> &Reduction) {
- return isInstrPartialReduction(Reduction.second.getLoopExitInstr());
- }))
+ if (any_of(Legal->getReductionVars(),
+ [&](const std::pair<PHINode *, RecurrenceDescriptor> &Reduction) {
+ return isInstrPartialReduction(
+ Reduction.second.getLoopExitInstr());
+ }))
return false;
// Epilogue vectorization code has not been auditted to ensure it handles
@@ -7277,13 +7280,16 @@ void LoopVectorizationCostModel::collectValuesToIgnore() {
}
// Ignore any values that we know will be flattened
- for(auto Reduction : this->Legal->getReductionVars()) {
+ for (auto Reduction : this->Legal->getReductionVars()) {
auto &Recurrence = Reduction.second;
- if(isInstrPartialReduction(Recurrence.getLoopExitInstr())) {
- SmallVector<Value*, 4> PartialReductionValues;
- getPartialReductionInstrChain(Recurrence.getLoopExitInstr(), PartialReductionValues);
- ValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
- VecValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
+ if (isInstrPartialReduction(Recurrence.getLoopExitInstr())) {
+ SmallVector<Value *, 4> PartialReductionValues;
+ getPartialReductionInstrChain(Recurrence.getLoopExitInstr(),
+ PartialReductionValues);
+ ValuesToIgnore.insert(PartialReductionValues.begin(),
+ PartialReductionValues.end());
+ VecValuesToIgnore.insert(PartialReductionValues.begin(),
+ PartialReductionValues.end());
}
}
}
@@ -8640,20 +8646,22 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
*CI);
}
- if(auto *PartialReduce = tryToCreatePartialReduction(Range, Instr, Operands))
+ if (auto *PartialReduce = tryToCreatePartialReduction(Range, Instr, Operands))
return PartialReduce;
return tryToWiden(Instr, Operands, VPBB);
}
-VPRecipeBase *VPRecipeBuilder::tryToCreatePartialReduction(
- VFRange &Range, Instruction *Instr, ArrayRef<VPValue *> Operands) {
+VPRecipeBase *
+VPRecipeBuilder::tryToCreatePartialReduction(VFRange &Range, Instruction *Instr,
+ ArrayRef<VPValue *> Operands) {
- if(isInstrPartialReduction(Instr)) {
+ if (isInstrPartialReduction(Instr)) {
auto EC = ElementCount::getScalable(16);
- if(std::find(Range.begin(), Range.end(), EC) == Range.end())
+ if (std::find(Range.begin(), Range.end(), EC) == Range.end())
return nullptr;
- return new VPPartialReductionRecipe(*Instr, make_range(Operands.begin(), Operands.end()));
+ return new VPPartialReductionRecipe(
+ *Instr, make_range(Operands.begin(), Operands.end()));
}
return nullptr;
}
@@ -8865,7 +8873,7 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
VPBB->appendRecipe(Recipe);
}
- for(auto &Recipe : *VPBB)
+ for (auto &Recipe : *VPBB)
Recipe.postInsertionOp();
VPBlockUtils::insertBlockAfter(new VPBasicBlock(), VPBB);
diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
index c439f22170..cbd0809163 100644
--- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
+++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
@@ -116,7 +116,8 @@ public:
ArrayRef<VPValue *> Operands,
VFRange &Range, VPBasicBlock *VPBB);
- VPRecipeBase* tryToCreatePartialReduction(VFRange &Range, Instruction* Instr, ArrayRef<VPValue*> Operands);
+ VPRecipeBase *tryToCreatePartialReduction(VFRange &Range, Instruction *Instr,
+ ArrayRef<VPValue *> Operands);
/// Set the recipe created for given ingredient.
void setRecipe(Instruction *I, VPRecipeBase *R) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 5a572ecb79..ffe1608b46 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -1887,7 +1887,6 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
unsigned VFScaleFactor = 1;
public:
-
/// Create a new VPReductionPHIRecipe for the reduction \p Phi described by \p
/// RdxDesc.
VPReductionPHIRecipe(PHINode *Phi, const RecurrenceDescriptor &RdxDesc,
@@ -1902,9 +1901,9 @@ public:
~VPReductionPHIRecipe() override = default;
VPReductionPHIRecipe *clone() override {
- auto *R =
- new VPReductionPHIRecipe(cast<PHINode>(getUnderlyingInstr()), RdxDesc,
- *getOperand(0), IsInLoop, IsOrdered, VFScaleFactor);
+ auto *R = new VPReductionPHIRecipe(cast<PHINode>(getUnderlyingInstr()),
+ RdxDesc, *getOperand(0), IsInLoop,
+ IsOrdered, VFScaleFactor);
R->addOperand(getBackedgeValue());
return R;
}
@@ -1915,9 +1914,7 @@ public:
return R->getVPDefID() == VPDef::VPReductionPHISC;
}
- void SetVFScaleFactor(unsigned ScaleFactor) {
- VFScaleFactor = ScaleFactor;
- }
+ void SetVFScaleFactor(unsigned ScaleFactor) { VFScaleFactor = ScaleFactor; }
/// Generate the phi/select nodes.
void execute(VPTransformState &State) override;
@@ -1941,12 +1938,12 @@ public:
class VPPartialReductionRecipe : public VPRecipeWithIRFlags {
unsigned Opcode;
+
public:
template <typename IterT>
- VPPartialReductionRecipe(Instruction &I,
- iterator_range<IterT> Operands) : VPRecipeWithIRFlags(
- VPDef::VPPartialReductionSC, Operands, I), Opcode(I.getOpcode())
- {}
+ VPPartialReductionRecipe(Instruction &I, iterator_range<IterT> Operands)
+ : VPRecipeWithIRFlags(VPDef::VPPartialReductionSC, Operands, I),
+ Opcode(I.getOpcode()) {}
~VPPartialReductionRecipe() override = default;
VPPartialReductionRecipe *clone() override {
auto *R = new VPPartialReductionRecipe(*getUnderlyingInstr(), operands());
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index 8a75668886..11a95db3b8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -208,7 +208,8 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPReplicateRecipe *R) {
llvm_unreachable("Unhandled opcode");
}
-Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPPartialReductionRecipe *R) {
+Type *
+VPTypeAnalysis::inferScalarTypeForRecipe(const VPPartialReductionRecipe *R) {
return R->getUnderlyingInstr()->getType();
}
@@ -242,7 +243,8 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
return inferScalarType(R->getOperand(0));
})
.Case<VPBlendRecipe, VPInstruction, VPWidenRecipe, VPReplicateRecipe,
- VPWidenCallRecipe, VPWidenMemoryRecipe, VPWidenSelectRecipe, VPPartialReductionRecipe>(
+ VPWidenCallRecipe, VPWidenMemoryRecipe, VPWidenSelectRecipe,
+ VPPartialReductionRecipe>(
[this](const auto *R) { return inferScalarTypeForRecipe(R); })
.Case<VPInterleaveRecipe>([V](const VPInterleaveRecipe *R) {
// TODO: Use info from interleave group.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 9aff5dd0a7..e305721d69 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -249,33 +249,33 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
State.setDebugLocFrom(getDebugLoc());
auto &Builder = State.Builder;
- switch(Opcode) {
+ switch (Opcode) {
case Instruction::Add: {
for (unsigned Part = 0; Part < State.UF; ++Part) {
- Value* Mul = nullptr;
- Value* Phi = nullptr;
- SmallVector<Value*, 2> Ops;
+ Value *Mul = nullptr;
+ Value *Phi = nullptr;
+ SmallVector<Value *, 2> Ops;
for (VPValue *VPOp : operands()) {
auto *Op = State.get(VPOp, Part);
Ops.push_back(Op);
- if(isa<PHINode>(Op))
+ if (isa<PHINode>(Op))
Phi = Op;
else
Mul = Op;
}
assert(Phi && Mul && "Phi and Mul must be set");
- assert(isa<ScalableVectorType>(Ops[0]->getType()) && "Type must be a scalable vector");
+ assert(isa<ScalableVectorType>(Ops[0]->getType()) &&
+ "Type must be a scalable vector");
ScalableVectorType *FullTy = cast<ScalableVectorType>(Ops[0]->getType());
Type *RetTy = ScalableVectorType::get(FullTy->getScalarType(), 4);
Intrinsic::ID PartialIntrinsic = Intrinsic::not_intrinsic;
- switch(Opcode) {
+ switch (Opcode) {
case Instruction::Add:
- PartialIntrinsic =
- Intrinsic::experimental_vector_partial_reduce_add;
+ PartialIntrinsic = Intrinsic::experimental_vector_partial_reduce_add;
break;
default:
llvm_unreachable("Opcode not handled");
@@ -283,7 +283,8 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
assert(PartialIntrinsic != Intrinsic::not_intrinsic);
- Value *V = Builder.CreateIntrinsic(RetTy, PartialIntrinsic, Mul, nullptr, Twine("partial.reduce"));
+ Value *V = Builder.CreateIntrinsic(RetTy, PartialIntrinsic, Mul, nullptr,
+ Twine("partial.reduce"));
V = Builder.CreateNAryOp(Opcode, {V, Phi});
if (auto *VecOp = dyn_cast<Instruction>(V))
setFlags(VecOp);
@@ -295,7 +296,8 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
break;
}
default:
- LLVM_DEBUG(dbgs() << "LV: Found an unhandled opcode : " << Instruction::getOpcodeName(Opcode));
+ LLVM_DEBUG(dbgs() << "LV: Found an unhandled opcode : "
+ << Instruction::getOpcodeName(Opcode));
llvm_unreachable("Unhandled instruction!");
}
}
@@ -306,7 +308,7 @@ void VPPartialReductionRecipe::postInsertionOp() {
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
void VPPartialReductionRecipe::print(raw_ostream &O, const Twine &Indent,
- VPSlotTracker &SlotTracker) const {
+ VPSlotTracker &SlotTracker) const {
O << Indent << "PARTIAL-REDUCE ";
printAsOperand(O, SlotTracker);
O << " = " << Instruction::getOpcodeName(Opcode);
@@ -2035,8 +2037,8 @@ void VPReductionPHIRecipe::execute(VPTransformState &State) {
// stage #1: We create a new vector PHI node with no incoming edges. We'll use
// this value when we vectorize all of the instructions that use the PHI.
bool ScalarPHI = VF.isScalar() || IsInLoop;
- Type *VecTy = ScalarPHI ? StartV->getType()
- : VectorType::get(StartV->getType(), VF);
+ Type *VecTy =
+ ScalarPHI ? StartV->getType() : VectorType::get(StartV->getType(), VF);
BasicBlock *HeaderBB = State.CFG.PrevBB;
assert(State.CurrentVectorLoop->getHeader() == HeaderBB &&
@@ -2060,8 +2062,7 @@ void VPReductionPHIRecipe::execute(VPTransformState &State) {
} else {
IRBuilderBase::InsertPointGuard IPBuilder(Builder);
Builder.SetInsertPoint(VectorPH->getTerminator());
- StartV = Iden =
- Builder.CreateVectorSplat(VF, StartV, "minmax.ident");
+ StartV = Iden = Builder.CreateVectorSplat(VF, StartV, "minmax.ident");
}
} else {
Iden = RdxDesc.getRecurrenceIdentity(RK, VecTy->getScalarType(),
``````````
</details>
https://github.com/llvm/llvm-project/pull/92418
More information about the llvm-commits mailing list