[llvm] [VPlan] Update final IV exit value via VPlan. (PR #112147)
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Fri Jan 17 13:26:25 PST 2025
================
@@ -666,6 +666,134 @@ static void legalizeAndOptimizeInductions(VPlan &Plan) {
}
}
+/// Return the wide IV behind \p VPV, or nullptr if there is none. \p VPV
+/// qualifies if it is a wide pointer induction, a wide int/FP induction without
+/// a truncate instruction, or an increment of a wide induction by its step.
+static VPWidenInductionRecipe *isOptimizableIVOrUse(VPValue *VPV) {
+ auto *WideIV = dyn_cast<VPWidenInductionRecipe>(VPV);
+ if (WideIV) {
+ // VPV itself is a wide induction, separately compute the end value for exit
+ // users if it is not a truncated IV (pointer inductions always qualify).
+ if (isa<VPWidenPointerInductionRecipe>(WideIV) ||
+ !cast<VPWidenIntOrFpInductionRecipe>(WideIV)->getTruncInst())
+ return WideIV;
+ return nullptr; // Truncated int/FP wide IV: not optimizable.
+ }
+
+ // Otherwise, check if VPV is a two-operand recipe that increments a wide IV.
+ VPRecipeBase *Def = VPV->getDefiningRecipe();
+ if (!Def || Def->getNumOperands() != 2)
+ return nullptr;
+ WideIV = dyn_cast<VPWidenInductionRecipe>(Def->getOperand(0));
+ if (!WideIV) // Operand 0 is not a wide IV; fall back to operand 1.
+ WideIV = dyn_cast<VPWidenInductionRecipe>(Def->getOperand(1));
+ if (!WideIV)
+ return nullptr;
+
+ auto IsWideIVInc = [&]() {
+ using namespace VPlanPatternMatch;
+ auto &ID = WideIV->getInductionDescriptor();
+
+ // Match VPV against the increment pattern for the induction's opcode.
+ VPValue *IVStep = WideIV->getStepValue();
+ switch (ID.getInductionOpcode()) {
+ case Instruction::Add: // NOTE(review): m_c_Binary presumably accepts either operand order - confirm.
+ return match(VPV, m_c_Binary<Instruction::Add>(m_Specific(WideIV),
+ m_Specific(IVStep)));
+ case Instruction::FAdd:
+ return match(VPV, m_c_Binary<Instruction::FAdd>(m_Specific(WideIV),
+ m_Specific(IVStep)));
+ case Instruction::FSub:
+ return match(VPV, m_Binary<Instruction::FSub>(m_Specific(WideIV),
+ m_Specific(IVStep)));
+ case Instruction::Sub: {
+ // IVStep will be the negated step of the subtraction. Check if Step == -1
+ // * IVStep; both must be live-in ConstantInts for the comparison.
+ VPValue *Step;
+ if (!match(VPV,
+ m_Binary<Instruction::Sub>(m_VPValue(), m_VPValue(Step))) ||
+ !Step->isLiveIn() || !IVStep->isLiveIn())
+ return false;
+ auto *StepCI = dyn_cast<ConstantInt>(Step->getLiveInIRValue());
+ auto *IVStepCI = dyn_cast<ConstantInt>(IVStep->getLiveInIRValue());
+ return StepCI && IVStepCI &&
+ StepCI->getValue() == (-1 * IVStepCI->getValue());
+ }
+ default: // Any other opcode: only a pointer-induction GEP by the step qualifies.
+ return ID.getKind() == InductionDescriptor::IK_PtrInduction &&
+ match(VPV, m_GetElementPtr(m_Specific(WideIV),
+ m_Specific(WideIV->getStepValue())));
+ }
+ llvm_unreachable("should have been covered by switch above"); // Every switch case above returns.
+ };
+ return IsWideIVInc() ? WideIV : nullptr;
+}
+
+void VPlanTransforms::optimizeInductionExitUsers(
+ VPlan &Plan, DenseMap<VPValue *, VPValue *> &EndValues) {
+ using namespace VPlanPatternMatch;
+ SmallVector<VPIRBasicBlock *> ExitVPBBs(Plan.getExitBlocks());
+ if (ExitVPBBs.size() != 1)
+ return;
+
+ VPIRBasicBlock *ExitVPBB = ExitVPBBs[0];
+ VPBlockBase *PredVPBB = ExitVPBB->getSinglePredecessor();
+ if (!PredVPBB)
+ return;
+ assert(PredVPBB == Plan.getMiddleBlock() &&
+ "predecessor must be the middle block");
+
+ VPTypeAnalysis TypeInfo(Plan.getCanonicalIV()->getScalarType());
+ VPBuilder B(Plan.getMiddleBlock()->getTerminator());
+ for (VPRecipeBase &R : *ExitVPBB) {
+ auto *ExitIRI = cast<VPIRInstruction>(&R);
+ if (!isa<PHINode>(ExitIRI->getInstruction()))
+ break;
+
+ VPValue *Incoming;
+ if (!match(ExitIRI->getOperand(0),
+ m_VPInstruction<VPInstruction::ExtractFromEnd>(
+ m_VPValue(Incoming), m_SpecificInt(1))))
+ continue;
+
+ auto *WideIV = isOptimizableIVOrUse(Incoming);
+ if (!WideIV)
+ continue;
+ VPValue *EndValue = EndValues.lookup(WideIV);
+ if (!EndValue)
+ continue;
+
+ if (Incoming != WideIV) {
+ ExitIRI->setOperand(0, EndValue);
+ continue;
+ }
+
+ VPValue *Escape = nullptr;
+ VPValue *Step = WideIV->getStepValue();
+ Type *ScalarTy = TypeInfo.inferScalarType(WideIV);
+ if (ScalarTy->isIntegerTy()) {
+ Escape =
+ B.createNaryOp(Instruction::Sub, {EndValue, Step}, {}, "ind.escape");
+ } else if (ScalarTy->isPointerTy()) {
+ auto *Zero = Plan.getOrAddLiveIn(
+ ConstantInt::get(Step->getLiveInIRValue()->getType(), 0));
+ Escape = B.createPtrAdd(EndValue,
+ B.createNaryOp(Instruction::Sub, {Zero, Step}),
+ {}, "ind.escape");
+ } else if (ScalarTy->isFloatingPointTy()) {
+ const auto &ID = WideIV->getInductionDescriptor();
+ Escape = B.createNaryOp(
+ ID.getInductionBinOp()->getOpcode() == Instruction::FAdd
+ ? Instruction::FSub
+ : Instruction::FAdd,
+ {EndValue, Step}, {ID.getInductionBinOp()->getFastMathFlags()});
+ } else {
+ llvm_unreachable("all possible induction types must be handled");
+ }
----------------
fhahn wrote:
Yep, this should be covered by the TODO that was added, thanks.
https://github.com/llvm/llvm-project/pull/112147
More information about the llvm-commits
mailing list