[llvm] [Inliner] Add a helper around `SimplifiedValues.lookup`. NFCI (PR #118646)

Marina Taylor via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 4 06:38:53 PST 2024


https://github.com/citymarina created https://github.com/llvm/llvm-project/pull/118646

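Factor the repeated `dyn_cast<T>(V)` / `SimplifiedValues.lookup(V)` fallback pattern in `CallAnalyzer` into a single helper, `getDirectOrSimplifiedValue<T>`.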

From 40f339f644488c2093c9c3b28dd6af0e48fbbeee Mon Sep 17 00:00:00 2001
From: Marina Taylor <marina_taylor at apple.com>
Date: Wed, 4 Dec 2024 14:35:09 +0000
Subject: [PATCH] [Inliner] Add a helper around `SimplifiedValues.lookup`. NFCI

---
 llvm/lib/Analysis/InlineCost.cpp | 76 ++++++++++++--------------------
 1 file changed, 27 insertions(+), 49 deletions(-)

diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index 32acf23e1d0d0d..85287a39f2caad 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -422,6 +422,14 @@ class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> {
     return It->second;
   }
 
+  /// Returns V as a T if it already is one; otherwise returns the entry for V
+  /// in SimplifiedValues cast to T, or null if neither yields a T.
+  template <typename T> T *getDirectOrSimplifiedValue(Value *V) const {
+    if (auto *Direct = dyn_cast<T>(V))
+      return Direct;
+    return dyn_cast_if_present<T>(SimplifiedValues.lookup(V));
+  }
+
   // Custom simplification helper routines.
   bool isAllocaDerivedArg(Value *V);
   void disableSROAForArg(AllocaInst *SROAArg);
@@ -1435,10 +1443,8 @@ bool CallAnalyzer::accumulateGEPOffset(GEPOperator &GEP, APInt &Offset) {
 
   for (gep_type_iterator GTI = gep_type_begin(GEP), GTE = gep_type_end(GEP);
        GTI != GTE; ++GTI) {
-    ConstantInt *OpC = dyn_cast<ConstantInt>(GTI.getOperand());
-    if (!OpC)
-      if (Constant *SimpleOp = SimplifiedValues.lookup(GTI.getOperand()))
-        OpC = dyn_cast<ConstantInt>(SimpleOp);
+    ConstantInt *OpC =
+        getDirectOrSimplifiedValue<ConstantInt>(GTI.getOperand());
     if (!OpC)
       return false;
     if (OpC->isZero())
@@ -1552,9 +1558,7 @@ bool CallAnalyzer::visitPHI(PHINode &I) {
     if (&I == V)
       continue;
 
-    Constant *C = dyn_cast<Constant>(V);
-    if (!C)
-      C = SimplifiedValues.lookup(V);
+    Constant *C = getDirectOrSimplifiedValue<Constant>(V);
 
     std::pair<Value *, APInt> BaseAndOffset = {nullptr, ZeroOffset};
     if (!C && CheckSROA)
@@ -1639,7 +1643,7 @@ bool CallAnalyzer::visitGetElementPtr(GetElementPtrInst &I) {
   // Lambda to check whether a GEP's indices are all constant.
   auto IsGEPOffsetConstant = [&](GetElementPtrInst &GEP) {
     for (const Use &Op : GEP.indices())
-      if (!isa<Constant>(Op) && !SimplifiedValues.lookup(Op))
+      if (!getDirectOrSimplifiedValue<Constant>(Op))
         return false;
     return true;
   };
@@ -1666,9 +1670,7 @@ bool CallAnalyzer::visitGetElementPtr(GetElementPtrInst &I) {
 bool CallAnalyzer::simplifyInstruction(Instruction &I) {
   SmallVector<Constant *> COps;
   for (Value *Op : I.operands()) {
-    Constant *COp = dyn_cast<Constant>(Op);
-    if (!COp)
-      COp = SimplifiedValues.lookup(Op);
+    Constant *COp = getDirectOrSimplifiedValue<Constant>(Op);
     if (!COp)
       return false;
     COps.push_back(COp);
@@ -1691,10 +1693,7 @@ bool CallAnalyzer::simplifyInstruction(Instruction &I) {
 /// llvm.is.constant would evaluate.
 bool CallAnalyzer::simplifyIntrinsicCallIsConstant(CallBase &CB) {
   Value *Arg = CB.getArgOperand(0);
-  auto *C = dyn_cast<Constant>(Arg);
-
-  if (!C)
-    C = dyn_cast_or_null<Constant>(SimplifiedValues.lookup(Arg));
+  auto *C = getDirectOrSimplifiedValue<Constant>(Arg);
 
   Type *RT = CB.getFunctionType()->getReturnType();
   SimplifiedValues[&CB] = ConstantInt::get(RT, C ? 1 : 0);
@@ -2126,12 +2125,8 @@ bool CallAnalyzer::visitSub(BinaryOperator &I) {
 
 bool CallAnalyzer::visitBinaryOperator(BinaryOperator &I) {
   Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
-  Constant *CLHS = dyn_cast<Constant>(LHS);
-  if (!CLHS)
-    CLHS = SimplifiedValues.lookup(LHS);
-  Constant *CRHS = dyn_cast<Constant>(RHS);
-  if (!CRHS)
-    CRHS = SimplifiedValues.lookup(RHS);
+  Constant *CLHS = getDirectOrSimplifiedValue<Constant>(LHS);
+  Constant *CRHS = getDirectOrSimplifiedValue<Constant>(RHS);
 
   Value *SimpleV = nullptr;
   if (auto FI = dyn_cast<FPMathOperator>(&I))
@@ -2165,9 +2160,7 @@ bool CallAnalyzer::visitBinaryOperator(BinaryOperator &I) {
 
 bool CallAnalyzer::visitFNeg(UnaryOperator &I) {
   Value *Op = I.getOperand(0);
-  Constant *COp = dyn_cast<Constant>(Op);
-  if (!COp)
-    COp = SimplifiedValues.lookup(Op);
+  Constant *COp = getDirectOrSimplifiedValue<Constant>(Op);
 
   Value *SimpleV = simplifyFNegInst(
       COp ? COp : Op, cast<FPMathOperator>(I).getFastMathFlags(), DL);
@@ -2255,9 +2248,7 @@ bool CallAnalyzer::simplifyCallSite(Function *F, CallBase &Call) {
   SmallVector<Constant *, 4> ConstantArgs;
   ConstantArgs.reserve(Call.arg_size());
   for (Value *I : Call.args()) {
-    Constant *C = dyn_cast<Constant>(I);
-    if (!C)
-      C = dyn_cast_or_null<Constant>(SimplifiedValues.lookup(I));
+    Constant *C = getDirectOrSimplifiedValue<Constant>(I);
     if (!C)
       return false; // This argument doesn't map to a constant.
 
@@ -2288,14 +2279,9 @@ bool CallAnalyzer::isLoweredToCall(Function *F, CallBase &Call) {
     // platforms whose headers redirect memcpy to __memcpy_chk (e.g. Darwin), as
     // other platforms use memcpy intrinsics, which are already exempt from the
     // call penalty.
-    auto *LenOp = dyn_cast<ConstantInt>(Call.getOperand(2));
-    if (!LenOp)
-      LenOp = dyn_cast_or_null<ConstantInt>(
-          SimplifiedValues.lookup(Call.getOperand(2)));
-    auto *ObjSizeOp = dyn_cast<ConstantInt>(Call.getOperand(3));
-    if (!ObjSizeOp)
-      ObjSizeOp = dyn_cast_or_null<ConstantInt>(
-          SimplifiedValues.lookup(Call.getOperand(3)));
+    auto *LenOp = getDirectOrSimplifiedValue<ConstantInt>(Call.getOperand(2));
+    auto *ObjSizeOp =
+        getDirectOrSimplifiedValue<ConstantInt>(Call.getOperand(3));
     if (LenOp && ObjSizeOp &&
         LenOp->getLimitedValue() <= ObjSizeOp->getLimitedValue()) {
       return false;
@@ -2411,10 +2397,9 @@ bool CallAnalyzer::visitBranchInst(BranchInst &BI) {
   // shouldn't exist at all, but handling them makes the behavior of the
   // inliner more regular and predictable. Interestingly, conditional branches
   // which will fold away are also free.
-  return BI.isUnconditional() || isa<ConstantInt>(BI.getCondition()) ||
-         BI.getMetadata(LLVMContext::MD_make_implicit) ||
-         isa_and_nonnull<ConstantInt>(
-             SimplifiedValues.lookup(BI.getCondition()));
+  return BI.isUnconditional() ||
+         getDirectOrSimplifiedValue<ConstantInt>(BI.getCondition()) ||
+         BI.getMetadata(LLVMContext::MD_make_implicit);
 }
 
 bool CallAnalyzer::visitSelectInst(SelectInst &SI) {
@@ -2422,12 +2407,8 @@ bool CallAnalyzer::visitSelectInst(SelectInst &SI) {
   Value *TrueVal = SI.getTrueValue();
   Value *FalseVal = SI.getFalseValue();
 
-  Constant *TrueC = dyn_cast<Constant>(TrueVal);
-  if (!TrueC)
-    TrueC = SimplifiedValues.lookup(TrueVal);
-  Constant *FalseC = dyn_cast<Constant>(FalseVal);
-  if (!FalseC)
-    FalseC = SimplifiedValues.lookup(FalseVal);
+  Constant *TrueC = getDirectOrSimplifiedValue<Constant>(TrueVal);
+  Constant *FalseC = getDirectOrSimplifiedValue<Constant>(FalseVal);
   Constant *CondC =
       dyn_cast_or_null<Constant>(SimplifiedValues.lookup(SI.getCondition()));
 
@@ -2497,11 +2478,8 @@ bool CallAnalyzer::visitSelectInst(SelectInst &SI) {
 bool CallAnalyzer::visitSwitchInst(SwitchInst &SI) {
   // We model unconditional switches as free, see the comments on handling
   // branches.
-  if (isa<ConstantInt>(SI.getCondition()))
+  if (getDirectOrSimplifiedValue<ConstantInt>(SI.getCondition()))
     return true;
-  if (Value *V = SimplifiedValues.lookup(SI.getCondition()))
-    if (isa<ConstantInt>(V))
-      return true;
 
   // Assume the most general case where the switch is lowered into
   // either a jump table, bit test, or a balanced binary tree consisting of



More information about the llvm-commits mailing list