[llvm] [Analysis]: Allow inlining recursive call IF recursion depth is 1. (PR #119677)

Hassnaa Hamdi via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 25 16:38:37 PST 2025


https://github.com/hassnaaHamdi updated https://github.com/llvm/llvm-project/pull/119677

>From 10b54f308deceb7a441682407c154b8a8b968707 Mon Sep 17 00:00:00 2001
From: Hassnaa Hamdi <hassnaa.hamdi at arm.com>
Date: Thu, 12 Dec 2024 08:48:17 +0000
Subject: [PATCH 1/4] [Analysis]: Allow inlining recursive call IF recursion
 depth is 1.

---
 llvm/include/llvm/Analysis/InlineCost.h       |   2 +-
 llvm/lib/Analysis/InlineCost.cpp              |  80 ++++++++
 .../Transforms/Inline/inline-recursive-fn.ll  | 179 ++++++++++++++++++
 llvm/test/Transforms/Inline/inline-remark.ll  |   2 +-
 4 files changed, 261 insertions(+), 2 deletions(-)
 create mode 100644 llvm/test/Transforms/Inline/inline-recursive-fn.ll

diff --git a/llvm/include/llvm/Analysis/InlineCost.h b/llvm/include/llvm/Analysis/InlineCost.h
index ed54b0c077b4a..ac61ca064c28b 100644
--- a/llvm/include/llvm/Analysis/InlineCost.h
+++ b/llvm/include/llvm/Analysis/InlineCost.h
@@ -235,7 +235,7 @@ struct InlineParams {
   std::optional<bool> EnableDeferral;
 
   /// Indicate whether we allow inlining for recursive call.
-  std::optional<bool> AllowRecursiveCall = false;
+  std::optional<bool> AllowRecursiveCall = true;
 };
 
 std::optional<int> getStringFnAttrAsInt(CallBase &CB, StringRef AttrKind);
diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index 85287a39f2caa..c2ade73b2ddcb 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -20,6 +20,7 @@
 #include "llvm/Analysis/BlockFrequencyInfo.h"
 #include "llvm/Analysis/CodeMetrics.h"
 #include "llvm/Analysis/ConstantFolding.h"
+#include "llvm/Analysis/DomConditionCache.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/MemoryBuiltins.h"
@@ -1160,6 +1161,10 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
   std::optional<CostBenefitPair> getCostBenefitPair() { return CostBenefit; }
   bool wasDecidedByCostBenefit() const { return DecidedByCostBenefit; }
   bool wasDecidedByCostThreshold() const { return DecidedByCostThreshold; }
+  bool shouldCheckRecursiveCall() {
+    return IsRecursiveCall && AllowRecursiveCall;
+  }
+  bool shouldInlineRecursiveCall(CallBase &Call);
 };
 
 // Return true if CB is the sole call to local function Callee.
@@ -2880,6 +2885,68 @@ InlineResult CallAnalyzer::analyze() {
   return finalizeAnalysis();
 }
 
+bool InlineCostCallAnalyzer::shouldInlineRecursiveCall(CallBase &Call) {
+  CallInst *CI = cast<CallInst>(&Call);
+  auto CIB = CI->getParent();
+  // Only handle case when we have sinlge predecessor
+  if (auto Predecessor = CIB->getSinglePredecessor()) {
+    BranchInst *Br = dyn_cast<BranchInst>(Predecessor->getTerminator());
+    if (!Br || Br->isUnconditional()) {
+      return false;
+    }
+    Value *Var = Br->getCondition();
+    CmpInst *CmpInstr = dyn_cast<CmpInst>(Var);
+    if (CmpInstr && !isa<Constant>(CmpInstr->getOperand(1))) {
+      // Current logic of ValueTracking/DomConditionCache works only if RHS is
+      // constant.
+      return false;
+    }
+    unsigned ArgNum = 0;
+    Value *FuncArg = nullptr, *CallArg = nullptr;
+    // Check which func argument the cmp instr is using:
+    for (; ArgNum < CI->getFunction()->arg_size(); ArgNum++) {
+      FuncArg = CI->getFunction()->getArg(ArgNum);
+      CallArg = CI->getArgOperand(ArgNum);
+      if (CmpInstr) {
+        if ((FuncArg == CmpInstr->getOperand(0)) &&
+            (CallArg != CmpInstr->getOperand(0)))
+          break;
+      } else if (FuncArg == Var && (CallArg != Var))
+        break;
+    }
+    // Only handle the case when a func argument controls the cmp instruction:
+    if (ArgNum < CI->getFunction()->arg_size()) {
+      bool isTrueSuccessor = CIB == Br->getSuccessor(0);
+      if (CmpInstr) {
+        SimplifyQuery SQ(CI->getFunction()->getDataLayout(),
+                         dyn_cast<Instruction>(CallArg));
+        DomConditionCache DC;
+        DC.registerBranch(Br);
+        SQ.DC = &DC;
+        DominatorTree DT(*CI->getFunction());
+        SQ.DT = &DT;
+        Value *simplifiedInstruction = llvm::simplifyInstructionWithOperands(
+            CmpInstr, {CallArg, CmpInstr->getOperand(1)}, SQ);
+        if (!simplifiedInstruction)
+          return false;
+        if (auto *ConstVal =
+                dyn_cast<llvm::ConstantInt>(simplifiedInstruction)) {
+          if (ConstVal->isOne())
+            return !isTrueSuccessor;
+          return isTrueSuccessor;
+        }
+      } else {
+        if (auto *ConstVal = dyn_cast<llvm::ConstantInt>(CallArg)) {
+          if (ConstVal->isOne())
+            return !isTrueSuccessor;
+          return isTrueSuccessor;
+        }
+      }
+    }
+  }
+  return false;
+}
+
 void InlineCostCallAnalyzer::print(raw_ostream &OS) {
 #define DEBUG_PRINT_STAT(x) OS << "      " #x ": " << x << "\n"
   if (PrintInstructionComments)
@@ -3106,6 +3173,12 @@ InlineCost llvm::getInlineCost(
 
   LLVM_DEBUG(CA.dump());
 
+  // Check if callee function is recursive:
+  if (ShouldInline.isSuccess()) {
+    if (CA.shouldCheckRecursiveCall() && !CA.shouldInlineRecursiveCall(Call))
+      return InlineCost::getNever("deep recursion");
+  }
+
   // Always make cost benefit based decision explicit.
   // We use always/never here since threshold is not meaningful,
   // as it's not what drives cost-benefit analysis.
@@ -3148,6 +3221,13 @@ InlineResult llvm::isInlineViable(Function &F) {
 
       // Disallow recursive calls.
       Function *Callee = Call->getCalledFunction();
+      // This function is called when we have "alwaysinline" attribute.
+      // If we allowed the inlining here given that the recursive inlining is
+      // allowed, then there will be problem in the second trial of inlining,
+      // because the Inliner pass allow only one time inlining and then it
+      // inserts "noinline" attribute which will be in conflict with the
+      // attribute of "alwaysinline" so, "alwaysinline" for recursive function
+      // will be disallowed to avoid conflict of attributes.
       if (&F == Callee)
         return InlineResult::failure("recursive call");
 
diff --git a/llvm/test/Transforms/Inline/inline-recursive-fn.ll b/llvm/test/Transforms/Inline/inline-recursive-fn.ll
new file mode 100644
index 0000000000000..db90050d009a0
--- /dev/null
+++ b/llvm/test/Transforms/Inline/inline-recursive-fn.ll
@@ -0,0 +1,179 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes='inline,instcombine' < %s | FileCheck %s
+
+define float @inline_rec_true_successor(float %x, float %scale) {
+; CHECK-LABEL: define float @inline_rec_true_successor(
+; CHECK-SAME: float [[X:%.*]], float [[SCALE:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp olt float [[X]], 0.000000e+00
+; CHECK-NEXT:    br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_END:.*]]
+; CHECK:       [[COMMON_RET18:.*]]:
+; CHECK-NEXT:    [[COMMON_RET18_OP:%.*]] = phi float [ [[COMMON_RET18_OP_I:%.*]], %[[INLINE_REC_TRUE_SUCCESSOR_EXIT:.*]] ], [ [[MUL:%.*]], %[[IF_END]] ]
+; CHECK-NEXT:    ret float [[COMMON_RET18_OP]]
+; CHECK:       [[IF_THEN]]:
+; CHECK-NEXT:    br i1 false, label %[[IF_THEN_I:.*]], label %[[IF_END_I:.*]]
+; CHECK:       [[IF_THEN_I]]:
+; CHECK-NEXT:    br label %[[INLINE_REC_TRUE_SUCCESSOR_EXIT]]
+; CHECK:       [[IF_END_I]]:
+; CHECK-NEXT:    [[FNEG:%.*]] = fneg float [[X]]
+; CHECK-NEXT:    [[MUL_I:%.*]] = fmul float [[SCALE]], [[FNEG]]
+; CHECK-NEXT:    br label %[[INLINE_REC_TRUE_SUCCESSOR_EXIT]]
+; CHECK:       [[INLINE_REC_TRUE_SUCCESSOR_EXIT]]:
+; CHECK-NEXT:    [[COMMON_RET18_OP_I]] = phi float [ poison, %[[IF_THEN_I]] ], [ [[MUL_I]], %[[IF_END_I]] ]
+; CHECK-NEXT:    br label %[[COMMON_RET18]]
+; CHECK:       [[IF_END]]:
+; CHECK-NEXT:    [[MUL]] = fmul float [[X]], [[SCALE]]
+; CHECK-NEXT:    br label %[[COMMON_RET18]]
+;
+entry:
+  %cmp = fcmp olt float %x, 0.000000e+00
+  br i1 %cmp, label %if.then, label %if.end
+
+common.ret18:                                     ; preds = %if.then, %if.end
+  %common.ret18.op = phi float [ %call, %if.then ], [ %mul, %if.end ]
+  ret float %common.ret18.op
+
+if.then:                                          ; preds = %entry
+  %fneg = fneg float %x
+  %call = tail call float @inline_rec_true_successor(float %fneg, float %scale)
+  br label %common.ret18
+
+if.end:                                           ; preds = %entry
+  %mul = fmul float %x, %scale
+  br label %common.ret18
+}
+
+define float @test_inline_rec_true_successor(float %x, float %scale)  {
+entry:
+  %res = tail call float @inline_rec_true_successor(float %x, float %scale)
+  ret float %res
+}
+
+; Same as the previous test, except that the recursive call site is in the false successor.
+define float @inline_rec_false_successor(float %x, float %scale) {
+; CHECK-LABEL: define float @inline_rec_false_successor(
+; CHECK-SAME: float [[Y:%.*]], float [[SCALE:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp uge float [[Y]], 0.000000e+00
+; CHECK-NEXT:    br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_END:.*]]
+; CHECK:       [[COMMON_RET18:.*]]:
+; CHECK-NEXT:    [[COMMON_RET18_OP:%.*]] = phi float [ [[MUL:%.*]], %[[IF_THEN]] ], [ [[COMMON_RET18_OP_I:%.*]], %[[INLINE_REC_FALSE_SUCCESSOR_EXIT:.*]] ]
+; CHECK-NEXT:    ret float [[COMMON_RET18_OP]]
+; CHECK:       [[IF_THEN]]:
+; CHECK-NEXT:    [[MUL]] = fmul float [[Y]], [[SCALE]]
+; CHECK-NEXT:    br label %[[COMMON_RET18]]
+; CHECK:       [[IF_END]]:
+; CHECK-NEXT:    br i1 true, label %[[IF_THEN_I:.*]], label %[[IF_END_I:.*]]
+; CHECK:       [[IF_THEN_I]]:
+; CHECK-NEXT:    [[FNEG:%.*]] = fneg float [[Y]]
+; CHECK-NEXT:    [[MUL_I:%.*]] = fmul float [[SCALE]], [[FNEG]]
+; CHECK-NEXT:    br label %[[INLINE_REC_FALSE_SUCCESSOR_EXIT]]
+; CHECK:       [[IF_END_I]]:
+; CHECK-NEXT:    br label %[[INLINE_REC_FALSE_SUCCESSOR_EXIT]]
+; CHECK:       [[INLINE_REC_FALSE_SUCCESSOR_EXIT]]:
+; CHECK-NEXT:    [[COMMON_RET18_OP_I]] = phi float [ [[MUL_I]], %[[IF_THEN_I]] ], [ poison, %[[IF_END_I]] ]
+; CHECK-NEXT:    br label %[[COMMON_RET18]]
+;
+entry:
+  %cmp = fcmp uge float %x, 0.000000e+00
+  br i1 %cmp, label %if.then, label %if.end
+
+common.ret18:                                     ; preds = %if.then, %if.end
+  %common.ret18.op = phi float [ %mul, %if.then ], [ %call, %if.end ]
+  ret float %common.ret18.op
+
+if.then:                                          ; preds = %entry
+  %mul = fmul float %x, %scale
+  br label %common.ret18
+
+if.end:                                           ; preds = %entry
+  %fneg = fneg float %x
+  %call = tail call float @inline_rec_false_successor(float %fneg, float %scale)
+  br label %common.ret18
+}
+
+define float @test_inline_rec_false_successor(float %x, float %scale)  {
+entry:
+  %res = tail call float @inline_rec_false_successor(float %x, float %scale)
+  ret float %res
+}
+
+; Test a branch whose condition is a plain i1 argument rather than a cmp instruction.
+define float @inline_rec_no_cmp(i1 %flag, float %scale) {
+; CHECK-LABEL: define float @inline_rec_no_cmp(
+; CHECK-SAME: i1 [[FLAG:%.*]], float [[SCALE:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    br i1 [[FLAG]], label %[[IF_THEN:.*]], label %[[IF_END:.*]]
+; CHECK:       [[IF_THEN]]:
+; CHECK-NEXT:    [[SUM:%.*]] = fadd float [[SCALE]], 5.000000e+00
+; CHECK-NEXT:    [[SUM1:%.*]] = fadd float [[SUM]], [[SCALE]]
+; CHECK-NEXT:    br label %[[COMMON_RET:.*]]
+; CHECK:       [[IF_END]]:
+; CHECK-NEXT:    [[SUM2:%.*]] = fadd float [[SCALE]], 5.000000e+00
+; CHECK-NEXT:    br label %[[COMMON_RET]]
+; CHECK:       [[COMMON_RET]]:
+; CHECK-NEXT:    [[COMMON_RET_RES:%.*]] = phi float [ [[SUM1]], %[[IF_THEN]] ], [ [[SUM2]], %[[IF_END]] ]
+; CHECK-NEXT:    ret float [[COMMON_RET_RES]]
+;
+entry:
+  br i1 %flag, label %if.then, label %if.end
+if.then:
+  %res = tail call float @inline_rec_no_cmp(i1 false, float %scale)
+  %sum1 = fadd float %res, %scale
+  br label %common.ret
+if.end:
+  %sum2 = fadd float %scale, 5.000000e+00
+  br label %common.ret
+common.ret:
+  %common.ret.res = phi float [ %sum1, %if.then ], [ %sum2, %if.end ]
+  ret float %common.ret.res
+}
+
+define float @test_inline_rec_no_cmp(i1 %flag, float %scale)  {
+entry:
+  %res = tail call float @inline_rec_no_cmp(i1 %flag, float %scale)
+  ret float %res
+}
+
+define float @no_inline_rec(float %x, float %scale) {
+; CHECK-LABEL: define float @no_inline_rec(
+; CHECK-SAME: float [[Z:%.*]], float [[SCALE:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp olt float [[Z]], 5.000000e+00
+; CHECK-NEXT:    br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_END:.*]]
+; CHECK:       [[COMMON_RET18:.*]]:
+; CHECK-NEXT:    [[COMMON_RET18_OP:%.*]] = phi float [ [[FNEG1:%.*]], %[[IF_THEN]] ], [ [[MUL:%.*]], %[[IF_END]] ]
+; CHECK-NEXT:    ret float [[COMMON_RET18_OP]]
+; CHECK:       [[IF_THEN]]:
+; CHECK-NEXT:    [[FADD:%.*]] = fadd float [[Z]], 5.000000e+00
+; CHECK-NEXT:    [[CALL:%.*]] = tail call float @no_inline_rec(float [[FADD]], float [[SCALE]])
+; CHECK-NEXT:    [[FNEG1]] = fneg float [[CALL]]
+; CHECK-NEXT:    br label %[[COMMON_RET18]]
+; CHECK:       [[IF_END]]:
+; CHECK-NEXT:    [[MUL]] = fmul float [[Z]], [[SCALE]]
+; CHECK-NEXT:    br label %[[COMMON_RET18]]
+;
+entry:
+  %cmp = fcmp olt float %x, 5.000000e+00
+  br i1 %cmp, label %if.then, label %if.end
+
+common.ret18:                                     ; preds = %if.then, %if.end
+  %common.ret18.op = phi float [ %fneg1, %if.then ], [ %mul, %if.end ]
+  ret float %common.ret18.op
+
+if.then:                                          ; preds = %entry
+  %fadd = fadd float %x, 5.000000e+00
+  %call = tail call float @no_inline_rec(float %fadd, float %scale)
+  %fneg1 = fneg float %call
+  br label %common.ret18
+
+if.end:                                           ; preds = %entry
+  %mul = fmul float %x, %scale
+  br label %common.ret18
+}
+
+define float @test_no_inline(float %x, float %scale)  {
+entry:
+  %res = tail call float @no_inline_rec(float %x, float %scale)
+  ret float %res
+}
diff --git a/llvm/test/Transforms/Inline/inline-remark.ll b/llvm/test/Transforms/Inline/inline-remark.ll
index 166c7bb065924..5af20e42dd31a 100644
--- a/llvm/test/Transforms/Inline/inline-remark.ll
+++ b/llvm/test/Transforms/Inline/inline-remark.ll
@@ -56,6 +56,6 @@ define void @test3() {
 }
 
 ; CHECK: attributes [[ATTR1]] = { "inline-remark"="(cost=25, threshold=0)" }
-; CHECK: attributes [[ATTR2]] = { "inline-remark"="(cost=never): recursive" }
+; CHECK: attributes [[ATTR2]] = { "inline-remark"="(cost=0, threshold=0)" }
 ; CHECK: attributes [[ATTR3]] = { "inline-remark"="unsupported operand bundle; (cost={{.*}}, threshold={{.*}})" }
 ; CHECK: attributes [[ATTR4]] = { alwaysinline "inline-remark"="(cost=never): recursive call" }

>From 85ffb114b0ff0b601931566231cd42a7c865322b Mon Sep 17 00:00:00 2001
From: Hassnaa Hamdi <hassnaa.hamdi at arm.com>
Date: Tue, 17 Dec 2024 17:32:10 +0000
Subject: [PATCH 2/4] [InlineCost]: fold CmpInstr when next iteration of a
 recursive call can simplify the Cmp decision.

---
 llvm/include/llvm/Analysis/InlineCost.h      |   2 +-
 llvm/lib/Analysis/InlineCost.cpp             | 146 +++++++++----------
 llvm/test/Transforms/Inline/inline-remark.ll |   2 +-
 3 files changed, 69 insertions(+), 81 deletions(-)

diff --git a/llvm/include/llvm/Analysis/InlineCost.h b/llvm/include/llvm/Analysis/InlineCost.h
index ac61ca064c28b..ed54b0c077b4a 100644
--- a/llvm/include/llvm/Analysis/InlineCost.h
+++ b/llvm/include/llvm/Analysis/InlineCost.h
@@ -235,7 +235,7 @@ struct InlineParams {
   std::optional<bool> EnableDeferral;
 
   /// Indicate whether we allow inlining for recursive call.
-  std::optional<bool> AllowRecursiveCall = true;
+  std::optional<bool> AllowRecursiveCall = false;
 };
 
 std::optional<int> getStringFnAttrAsInt(CallBase &CB, StringRef AttrKind);
diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index c2ade73b2ddcb..5d81c503d5178 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -441,6 +441,7 @@ class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> {
   bool canFoldInboundsGEP(GetElementPtrInst &I);
   bool accumulateGEPOffset(GEPOperator &GEP, APInt &Offset);
   bool simplifyCallSite(Function *F, CallBase &Call);
+  bool simplifyCmpInst(Function *F, CmpInst &Cmp);
   bool simplifyInstruction(Instruction &I);
   bool simplifyIntrinsicCallIsConstant(CallBase &CB);
   bool simplifyIntrinsicCallObjectSize(CallBase &CB);
@@ -1161,10 +1162,6 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
   std::optional<CostBenefitPair> getCostBenefitPair() { return CostBenefit; }
   bool wasDecidedByCostBenefit() const { return DecidedByCostBenefit; }
   bool wasDecidedByCostThreshold() const { return DecidedByCostThreshold; }
-  bool shouldCheckRecursiveCall() {
-    return IsRecursiveCall && AllowRecursiveCall;
-  }
-  bool shouldInlineRecursiveCall(CallBase &Call);
 };
 
 // Return true if CB is the sole call to local function Callee.
@@ -1671,6 +1668,68 @@ bool CallAnalyzer::visitGetElementPtr(GetElementPtrInst &I) {
   return isGEPFree(I);
 }
 
+/// Simplify \p Cmp if its RHS is a constant and ValueTracking can fold it.
+/// This handles the case where \p Cmp guards a recursive call whose arguments
+/// make \p Cmp fail/succeed in the next iteration of the recursion.
+bool CallAnalyzer::simplifyCmpInst(Function *F, CmpInst &Cmp) {
+  // Bail out if the RHS is NOT const:
+  if (!isa<Constant>(Cmp.getOperand(1)))
+    return false;
+  auto *CmpOp = Cmp.getOperand(0);
+  // Iterate over the users of the function to check if it's a recursive function:
+  for (auto *U : F->users()) {
+    CallInst *Call = dyn_cast<CallInst>(U);
+    if (!Call || Call->getFunction() != F)
+      continue;
+    auto *CallBB = Call->getParent();
+    auto *Predecessor = CallBB->getSinglePredecessor();
+    // Only handle the case when the callsite has a single predecessor:
+    if (!Predecessor)
+      continue;
+
+    auto *Br = dyn_cast<BranchInst>(Predecessor->getTerminator());
+    if (!Br || Br->isUnconditional())
+      continue;
+    // Check if the Br condition is the same Cmp instr we are investigating:
+    auto *CmpInstr = dyn_cast<CmpInst>(Br->getCondition());
+    if (!CmpInstr || CmpInstr != &Cmp)
+      continue;
+    // Check if there are any arg of the recursive callsite is affecting the cmp instr:
+    bool ArgFound = false;
+    Value *FuncArg = nullptr, *CallArg = nullptr;
+    for (unsigned ArgNum = 0; ArgNum < F->arg_size() && ArgNum < Call->arg_size(); ArgNum ++) {
+      FuncArg = F->getArg(ArgNum);
+      CallArg = Call->getArgOperand(ArgNum);
+      if ((FuncArg == CmpOp) &&
+          (CallArg != CmpOp)) {
+        ArgFound = true;
+        break;
+      }
+    }
+    if (!ArgFound)
+      continue;
+    // Now we have a recursive call that is guarded by a cmp instruction.
+    // Check if this cmp can be simplified:
+    SimplifyQuery SQ(DL, dyn_cast<Instruction>(CallArg));
+    DomConditionCache DC;
+    DC.registerBranch(Br);
+    SQ.DC = &DC;
+    DominatorTree DT(*F);
+    SQ.DT = &DT;
+    Value *simplifiedInstruction = llvm::simplifyInstructionWithOperands(CmpInstr, {CallArg, Cmp.getOperand(1)}, SQ);
+    if (!simplifiedInstruction)
+      continue;
+    if (auto *ConstVal = dyn_cast<llvm::ConstantInt>(simplifiedInstruction)) {
+      bool isTrueSuccessor = CallBB == Br->getSuccessor(0);
+      SimplifiedValues[&Cmp] = ConstVal;
+      if (ConstVal->isOne())
+        return !isTrueSuccessor;
+      return isTrueSuccessor;
+    }
+  }
+  return false;
+}
+
 /// Simplify \p I if its operands are constants and update SimplifiedValues.
 bool CallAnalyzer::simplifyInstruction(Instruction &I) {
   SmallVector<Constant *> COps;
@@ -2055,6 +2114,10 @@ bool CallAnalyzer::visitCmpInst(CmpInst &I) {
   if (simplifyInstruction(I))
     return true;
 
+  // Try to handle comparison that can be simplified using ValueTracking.
+  if (simplifyCmpInst(I.getFunction(), I))
+    return true;
+
   if (I.getOpcode() == Instruction::FCmp)
     return false;
 
@@ -2885,68 +2948,6 @@ InlineResult CallAnalyzer::analyze() {
   return finalizeAnalysis();
 }
 
-bool InlineCostCallAnalyzer::shouldInlineRecursiveCall(CallBase &Call) {
-  CallInst *CI = cast<CallInst>(&Call);
-  auto CIB = CI->getParent();
-  // Only handle case when we have sinlge predecessor
-  if (auto Predecessor = CIB->getSinglePredecessor()) {
-    BranchInst *Br = dyn_cast<BranchInst>(Predecessor->getTerminator());
-    if (!Br || Br->isUnconditional()) {
-      return false;
-    }
-    Value *Var = Br->getCondition();
-    CmpInst *CmpInstr = dyn_cast<CmpInst>(Var);
-    if (CmpInstr && !isa<Constant>(CmpInstr->getOperand(1))) {
-      // Current logic of ValueTracking/DomConditionCache works only if RHS is
-      // constant.
-      return false;
-    }
-    unsigned ArgNum = 0;
-    Value *FuncArg = nullptr, *CallArg = nullptr;
-    // Check which func argument the cmp instr is using:
-    for (; ArgNum < CI->getFunction()->arg_size(); ArgNum++) {
-      FuncArg = CI->getFunction()->getArg(ArgNum);
-      CallArg = CI->getArgOperand(ArgNum);
-      if (CmpInstr) {
-        if ((FuncArg == CmpInstr->getOperand(0)) &&
-            (CallArg != CmpInstr->getOperand(0)))
-          break;
-      } else if (FuncArg == Var && (CallArg != Var))
-        break;
-    }
-    // Only handle the case when a func argument controls the cmp instruction:
-    if (ArgNum < CI->getFunction()->arg_size()) {
-      bool isTrueSuccessor = CIB == Br->getSuccessor(0);
-      if (CmpInstr) {
-        SimplifyQuery SQ(CI->getFunction()->getDataLayout(),
-                         dyn_cast<Instruction>(CallArg));
-        DomConditionCache DC;
-        DC.registerBranch(Br);
-        SQ.DC = &DC;
-        DominatorTree DT(*CI->getFunction());
-        SQ.DT = &DT;
-        Value *simplifiedInstruction = llvm::simplifyInstructionWithOperands(
-            CmpInstr, {CallArg, CmpInstr->getOperand(1)}, SQ);
-        if (!simplifiedInstruction)
-          return false;
-        if (auto *ConstVal =
-                dyn_cast<llvm::ConstantInt>(simplifiedInstruction)) {
-          if (ConstVal->isOne())
-            return !isTrueSuccessor;
-          return isTrueSuccessor;
-        }
-      } else {
-        if (auto *ConstVal = dyn_cast<llvm::ConstantInt>(CallArg)) {
-          if (ConstVal->isOne())
-            return !isTrueSuccessor;
-          return isTrueSuccessor;
-        }
-      }
-    }
-  }
-  return false;
-}
-
 void InlineCostCallAnalyzer::print(raw_ostream &OS) {
 #define DEBUG_PRINT_STAT(x) OS << "      " #x ": " << x << "\n"
   if (PrintInstructionComments)
@@ -3173,12 +3174,6 @@ InlineCost llvm::getInlineCost(
 
   LLVM_DEBUG(CA.dump());
 
-  // Check if callee function is recursive:
-  if (ShouldInline.isSuccess()) {
-    if (CA.shouldCheckRecursiveCall() && !CA.shouldInlineRecursiveCall(Call))
-      return InlineCost::getNever("deep recursion");
-  }
-
   // Always make cost benefit based decision explicit.
   // We use always/never here since threshold is not meaningful,
   // as it's not what drives cost-benefit analysis.
@@ -3221,13 +3216,6 @@ InlineResult llvm::isInlineViable(Function &F) {
 
       // Disallow recursive calls.
       Function *Callee = Call->getCalledFunction();
-      // This function is called when we have "alwaysinline" attribute.
-      // If we allowed the inlining here given that the recursive inlining is
-      // allowed, then there will be problem in the second trial of inlining,
-      // because the Inliner pass allow only one time inlining and then it
-      // inserts "noinline" attribute which will be in conflict with the
-      // attribute of "alwaysinline" so, "alwaysinline" for recursive function
-      // will be disallowed to avoid conflict of attributes.
       if (&F == Callee)
         return InlineResult::failure("recursive call");
 
diff --git a/llvm/test/Transforms/Inline/inline-remark.ll b/llvm/test/Transforms/Inline/inline-remark.ll
index 5af20e42dd31a..166c7bb065924 100644
--- a/llvm/test/Transforms/Inline/inline-remark.ll
+++ b/llvm/test/Transforms/Inline/inline-remark.ll
@@ -56,6 +56,6 @@ define void @test3() {
 }
 
 ; CHECK: attributes [[ATTR1]] = { "inline-remark"="(cost=25, threshold=0)" }
-; CHECK: attributes [[ATTR2]] = { "inline-remark"="(cost=0, threshold=0)" }
+; CHECK: attributes [[ATTR2]] = { "inline-remark"="(cost=never): recursive" }
 ; CHECK: attributes [[ATTR3]] = { "inline-remark"="unsupported operand bundle; (cost={{.*}}, threshold={{.*}})" }
 ; CHECK: attributes [[ATTR4]] = { alwaysinline "inline-remark"="(cost=never): recursive call" }

>From 67e9d4e4aae271d24a9df0a4bfad8614ac4ceef8 Mon Sep 17 00:00:00 2001
From: Hassnaa Hamdi <hassnaa.hamdi at arm.com>
Date: Tue, 17 Dec 2024 17:32:44 +0000
Subject: [PATCH 3/4] format

---
 llvm/lib/Analysis/InlineCost.cpp | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index 5d81c503d5178..f4304bc741ca6 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -1676,7 +1676,8 @@ bool CallAnalyzer::simplifyCmpInst(Function *F, CmpInst &Cmp) {
   if (!isa<Constant>(Cmp.getOperand(1)))
     return false;
   auto *CmpOp = Cmp.getOperand(0);
-  // Iterate over the users of the function to check if it's a recursive function:
+  // Iterate over the users of the function to check if it's a recursive
+  // function:
   for (auto *U : F->users()) {
     CallInst *Call = dyn_cast<CallInst>(U);
     if (!Call || Call->getFunction() != F)
@@ -1694,14 +1695,15 @@ bool CallAnalyzer::simplifyCmpInst(Function *F, CmpInst &Cmp) {
     auto *CmpInstr = dyn_cast<CmpInst>(Br->getCondition());
     if (!CmpInstr || CmpInstr != &Cmp)
       continue;
-    // Check if there are any arg of the recursive callsite is affecting the cmp instr:
+    // Check if there are any arg of the recursive callsite is affecting the cmp
+    // instr:
     bool ArgFound = false;
     Value *FuncArg = nullptr, *CallArg = nullptr;
-    for (unsigned ArgNum = 0; ArgNum < F->arg_size() && ArgNum < Call->arg_size(); ArgNum ++) {
+    for (unsigned ArgNum = 0;
+         ArgNum < F->arg_size() && ArgNum < Call->arg_size(); ArgNum++) {
       FuncArg = F->getArg(ArgNum);
       CallArg = Call->getArgOperand(ArgNum);
-      if ((FuncArg == CmpOp) &&
-          (CallArg != CmpOp)) {
+      if ((FuncArg == CmpOp) && (CallArg != CmpOp)) {
         ArgFound = true;
         break;
       }
@@ -1716,7 +1718,8 @@ bool CallAnalyzer::simplifyCmpInst(Function *F, CmpInst &Cmp) {
     SQ.DC = &DC;
     DominatorTree DT(*F);
     SQ.DT = &DT;
-    Value *simplifiedInstruction = llvm::simplifyInstructionWithOperands(CmpInstr, {CallArg, Cmp.getOperand(1)}, SQ);
+    Value *simplifiedInstruction = llvm::simplifyInstructionWithOperands(
+        CmpInstr, {CallArg, Cmp.getOperand(1)}, SQ);
     if (!simplifiedInstruction)
       continue;
     if (auto *ConstVal = dyn_cast<llvm::ConstantInt>(simplifiedInstruction)) {

>From 4dd37cb3c72825394138f09c4bcddd1210b3ff63 Mon Sep 17 00:00:00 2001
From: Hassnaa Hamdi <hassnaa.hamdi at arm.com>
Date: Wed, 26 Feb 2025 00:38:19 +0000
Subject: [PATCH 4/4] [InlineCost]: Reuse the DominatorTree and only consider
 direct self-calls in simplifyCmpInst

Compute the DominatorTree lazily in CallAnalyzer and recompute it only when
it is empty or was built for a different function, instead of building a
fresh tree for every recursive call site that simplifyCmpInst inspects.

Also require the user of F to be a call to F: F->users() also contains
calls that merely pass F as an argument, and their operands must not be
matched against F's formal arguments.

Change-Id: I296a55dba127c1ae74cca9e0edcb10f2f44640e9
---
 llvm/lib/Analysis/InlineCost.cpp | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index f4304bc741ca6..9562a7ed77a58 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -262,6 +262,10 @@ class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> {
   // Cache the DataLayout since we use it a lot.
   const DataLayout &DL;
 
+  /// Dominator tree of the function being analyzed. Computed lazily by
+  /// simplifyCmpInst() and reused for every compare in that function.
+  DominatorTree DT;
+
   /// The OptimizationRemarkEmitter available for this compilation.
   OptimizationRemarkEmitter *ORE;
 
@@ -1680,7 +1684,8 @@ bool CallAnalyzer::simplifyCmpInst(Function *F, CmpInst &Cmp) {
   // function:
   for (auto *U : F->users()) {
     CallInst *Call = dyn_cast<CallInst>(U);
-    if (!Call || Call->getFunction() != F)
+    // F can also be used as a plain call argument; only handle self-calls.
+    if (!Call || Call->getCalledFunction() != F || Call->getFunction() != F)
       continue;
     auto *CallBB = Call->getParent();
     auto *Predecessor = CallBB->getSinglePredecessor();
@@ -1716,7 +1721,10 @@ bool CallAnalyzer::simplifyCmpInst(Function *F, CmpInst &Cmp) {
     DomConditionCache DC;
     DC.registerBranch(Br);
     SQ.DC = &DC;
-    DominatorTree DT(*F);
+    // Reuse the cached tree; (re)build it only if it is empty or was built
+    // for a different function.
+    if (DT.root_size() == 0 || DT.getRoot()->getParent() != F)
+      DT.recalculate(*F);
     SQ.DT = &DT;
     Value *simplifiedInstruction = llvm::simplifyInstructionWithOperands(
         CmpInstr, {CallArg, Cmp.getOperand(1)}, SQ);



More information about the llvm-commits mailing list