[llvm] [LV] Optimise latch exit induction users for some early exit loops (PR #128880)

Benjamin Maxwell via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 26 10:03:02 PST 2025


================
@@ -730,67 +730,74 @@ static VPWidenInductionRecipe *getOptimizableIVOf(VPValue *VPV) {
   return IsWideIVInc() ? WideIV : nullptr;
 }
 
-void VPlanTransforms::optimizeInductionExitUsers(
-    VPlan &Plan, DenseMap<VPValue *, VPValue *> &EndValues) {
+static VPValue *
+optimizeLatchExitInductionUser(VPlan &Plan, VPTypeAnalysis &TypeInfo,
+                               VPBlockBase *PredVPBB, VPValue *Op,
+                               DenseMap<VPValue *, VPValue *> &EndValues) {
   using namespace VPlanPatternMatch;
-  SmallVector<VPIRBasicBlock *> ExitVPBBs(Plan.getExitBlocks());
-  if (ExitVPBBs.size() != 1)
-    return;
 
-  VPIRBasicBlock *ExitVPBB = ExitVPBBs[0];
-  VPBlockBase *PredVPBB = ExitVPBB->getSinglePredecessor();
-  if (!PredVPBB)
-    return;
-  assert(PredVPBB == Plan.getMiddleBlock() &&
-         "predecessor must be the middle block");
-
-  VPTypeAnalysis TypeInfo(Plan.getCanonicalIV()->getScalarType());
-  VPBuilder B(Plan.getMiddleBlock()->getTerminator());
-  for (VPRecipeBase &R : *ExitVPBB) {
-    auto *ExitIRI = cast<VPIRInstruction>(&R);
-    if (!isa<PHINode>(ExitIRI->getInstruction()))
-      break;
+  VPValue *Incoming;
+  if (!match(Op, m_VPInstruction<VPInstruction::ExtractFromEnd>(
+                     m_VPValue(Incoming), m_SpecificInt(1))))
+    return nullptr;
 
-    VPValue *Incoming;
-    if (!match(ExitIRI->getOperand(0),
-               m_VPInstruction<VPInstruction::ExtractFromEnd>(
-                   m_VPValue(Incoming), m_SpecificInt(1))))
-      continue;
+  auto *WideIV = getOptimizableIVOf(Incoming);
+  if (!WideIV)
+    return nullptr;
 
-    auto *WideIV = getOptimizableIVOf(Incoming);
-    if (!WideIV)
-      continue;
-    VPValue *EndValue = EndValues.lookup(WideIV);
-    assert(EndValue && "end value must have been pre-computed");
+  VPValue *EndValue = EndValues.lookup(WideIV);
+  assert(EndValue && "end value must have been pre-computed");
+
+  // This only happens if Incoming is the increment of an induction recipe.
+  if (Incoming != WideIV)
+    return EndValue;
+
+  // Otherwise subtract the step from the EndValue.
+  VPBuilder B(cast<VPBasicBlock>(PredVPBB)->getTerminator());
+  VPValue *Step = WideIV->getStepValue();
+  Type *ScalarTy = TypeInfo.inferScalarType(WideIV);
+  VPValue *Escape = nullptr;
+  if (ScalarTy->isIntegerTy()) {
+    Escape =
+        B.createNaryOp(Instruction::Sub, {EndValue, Step}, {}, "ind.escape");
+  } else if (ScalarTy->isPointerTy()) {
+    auto *Zero = Plan.getOrAddLiveIn(
+        ConstantInt::get(Step->getLiveInIRValue()->getType(), 0));
+    Escape =
+        B.createPtrAdd(EndValue, B.createNaryOp(Instruction::Sub, {Zero, Step}),
+                       {}, "ind.escape");
+  } else if (ScalarTy->isFloatingPointTy()) {
+    const auto &ID = WideIV->getInductionDescriptor();
+    Escape = B.createNaryOp(
+        ID.getInductionBinOp()->getOpcode() == Instruction::FAdd
+            ? Instruction::FSub
+            : Instruction::FAdd,
+        {EndValue, Step}, {ID.getInductionBinOp()->getFastMathFlags()});
+  } else {
+    llvm_unreachable("all possible induction types must be handled");
+  }
+  return Escape;
----------------
MacDue wrote:

nit: No need for the `Escape` variable or the `else if`s; each branch can return directly.
```suggestion
  if (ScalarTy->isIntegerTy())
    return B.createNaryOp(Instruction::Sub, {EndValue, Step}, {}, "ind.escape");
  if (ScalarTy->isPointerTy()) {
    auto *Zero = Plan.getOrAddLiveIn(
        ConstantInt::get(Step->getLiveInIRValue()->getType(), 0));
    return B.createPtrAdd(EndValue, B.createNaryOp(Instruction::Sub, {Zero, Step}),
                       {}, "ind.escape");
  }
  if (ScalarTy->isFloatingPointTy()) {
    const auto &ID = WideIV->getInductionDescriptor();
    return B.createNaryOp(
        ID.getInductionBinOp()->getOpcode() == Instruction::FAdd
            ? Instruction::FSub
            : Instruction::FAdd,
        {EndValue, Step}, {ID.getInductionBinOp()->getFastMathFlags()});
  }
  llvm_unreachable("all possible induction types must be handled");
```

https://github.com/llvm/llvm-project/pull/128880


More information about the llvm-commits mailing list