[llvm] [SystemZ] Remove high inlining threshold multiplier. (PR #106058)

Jonas Paulsson via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 30 07:48:08 PDT 2024


https://github.com/JonPsson1 updated https://github.com/llvm/llvm-project/pull/106058

>From 876ae9d28b0e68dded76e9687cc50782b18c87a8 Mon Sep 17 00:00:00 2001
From: Jonas Paulsson <paulson1 at linux.ibm.com>
Date: Thu, 8 Aug 2024 12:00:00 +0200
Subject: [PATCH 1/4] Experiments

---
 .../SystemZ/SystemZTargetTransformInfo.cpp    | 223 +++++++++++++++++-
 .../SystemZ/SystemZTargetTransformInfo.h      |   2 +-
 2 files changed, 223 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
index e44777c5c48575..d9c17cd2b1311a 100644
--- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
+++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
@@ -53,20 +53,241 @@ static bool isUsedAsMemCpySource(const Value *V, bool &OtherUse) {
   return UsedAsMemCpySource;
 }
 
+static void countNumMemAccesses(const Value *Ptr, unsigned &NumStores,
+                                unsigned &NumLoads, const Function *F = nullptr) {
+  if (!isa<PointerType>(Ptr->getType()))
+    return;
+  for (const User *U : Ptr->users())
+    if (const Instruction *User = dyn_cast<Instruction>(U)) {
+      if (User->getParent()->getParent() == F || !F) {
+        if (const auto *SI = dyn_cast<StoreInst>(User)) {
+          if (SI->getPointerOperand() == Ptr && !SI->isVolatile())
+            NumStores++;
+        }
+        else if (const auto *LI = dyn_cast<LoadInst>(User)) {
+          if (LI->getPointerOperand() == Ptr && !LI->isVolatile())
+            NumLoads++;
+        }
+        else if (const auto *GEP = dyn_cast<GetElementPtrInst>(User)) {
+          if (GEP->getPointerOperand() == Ptr)
+            countNumMemAccesses(GEP, NumStores, NumLoads);
+        }
+      }
+    }
+}
+
+static unsigned usesAroundCall(const CallBase *CB, const GlobalVariable *GV) {
+  unsigned Uses = 0;
+  std::set<const Value *> Ptrs;
+  Ptrs.insert(GV);
+
+  const BasicBlock *BB = CB->getParent();
+  const unsigned CutOff = 20;
+  BasicBlock::const_iterator II = CB->getIterator();
+  for (unsigned N = 0; N < CutOff && II != BB->begin(); N++)
+    II--;
+  BasicBlock::const_iterator EE = CB->getIterator();
+  for (unsigned N = 0; N < CutOff && EE != BB->end(); N++)
+    EE++;
+  
+  for (; II != EE; ++II) {
+    if (const auto *SI = dyn_cast<StoreInst>(II)) {
+      if (Ptrs.count(SI->getPointerOperand()) && !SI->isVolatile())
+        Uses++;
+    }
+    else if (const auto *LI = dyn_cast<LoadInst>(II)) {
+      if (Ptrs.count(LI->getPointerOperand()) && !LI->isVolatile())
+        Uses++;
+    }
+    else if (const auto *GEP = dyn_cast<GetElementPtrInst>(II)) {
+      if (Ptrs.count(GEP->getPointerOperand()))
+        Ptrs.insert(GEP);
+    }
+  }
+  return Uses;
+}
+
+static unsigned usesEntryExit(const Function *F, const GlobalVariable *GV) {
+  unsigned Uses = 0;
+  std::set<const Value *> Ptrs;
+  Ptrs.insert(GV);
+
+  const unsigned CutOff = 100;
+  const BasicBlock *BB = &F->getEntryBlock();
+  unsigned N = 0;
+  for (BasicBlock::const_iterator II = BB->begin();
+       II != BB->end() && N < CutOff; ++II, N++) {
+    if (const auto *SI = dyn_cast<StoreInst>(II)) {
+      if (Ptrs.count(SI->getPointerOperand()) && !SI->isVolatile())
+        Uses++;
+    }
+    else if (const auto *LI = dyn_cast<LoadInst>(II)) {
+      if (Ptrs.count(LI->getPointerOperand()) && !LI->isVolatile())
+        Uses++;
+    }
+    else if (const auto *GEP = dyn_cast<GetElementPtrInst>(II)) {
+      if (Ptrs.count(GEP->getPointerOperand()))
+        Ptrs.insert(GEP);
+    }
+  }
+
+  Ptrs.clear();
+  Ptrs.insert(GV);
+  unsigned ReturnBlockUses = 0;
+  unsigned NumReturnBlocks = 0;
+  for (auto &BBII : *F) {
+    if (isa<ReturnInst>(BBII.getTerminator())) {
+      if (NumReturnBlocks++ > 0) {
+        ReturnBlockUses = 0;
+        break;
+      }
+      BasicBlock::const_iterator EE = BBII.getTerminator()->getIterator();
+      BasicBlock::const_iterator II = EE;
+      for (unsigned N = 0; N < CutOff && II != BBII.begin(); N++)
+        II--;
+      for (; II != EE; ++II) {
+        if (const auto *SI = dyn_cast<StoreInst>(II)) {
+          if (Ptrs.count(SI->getPointerOperand()) && !SI->isVolatile())
+            ReturnBlockUses++;
+        }
+        else if (const auto *LI = dyn_cast<LoadInst>(II)) {
+          if (Ptrs.count(LI->getPointerOperand()) && !LI->isVolatile())
+            ReturnBlockUses++;
+        }
+        else if (const auto *GEP = dyn_cast<GetElementPtrInst>(II)) {
+          if (Ptrs.count(GEP->getPointerOperand()))
+            Ptrs.insert(GEP);
+        }
+      }
+    }
+  }
+
+  return Uses + ReturnBlockUses;
+}
+
 unsigned SystemZTTIImpl::adjustInliningThreshold(const CallBase *CB) const {
   unsigned Bonus = 0;
 
+
+  // dbgs() << "INSTRCOUNT: " << CB->getCalledFunction()->getInstructionCount()
+  //        << CB->getCalledFunction()->getName() << "\n";
+  // if (CB->getCalledFunction()->getInstructionCount() == 216)
+  //   Bonus = 300;
+
+  // if (Function *Callee = CB->getCalledFunction()) {
+  //   const char *CallerFunName = CB->getParent()->getParent()->getName().data();
+  //   const char *CalleeFunName = Callee->getName().data();
+
+  //   if (std::strcmp(CallerFunName , "S_regmatch") == 0) {
+  //     if (std::strcmp(CalleeFunName, "S_reghopmaybe3") == 0 ||  // less important
+  //         std::strcmp(CalleeFunName, "S_regcppop") == 0 ||
+  //         std::strcmp(CalleeFunName, "S_regcppush") == 0)
+  //       return 250;
+  //   }
+  // }
+
+  // Check inlining with memory accesses common to caller and callee
+  // - Around call in caller?  entry/exit blocks in callee?
+  // - Globals used (much?) in both caller and callee
+  // - Specific type of pattern: load; inc/dec; store ?
+  // - non-volatile loads/stores?
+  // - int/fp loads/stores?  ptr?
+  // - num occurences in caller?
+  // - or specifically 2+ functions inlined if many common accesses?
+  // - specifically 2+ functions getting same adress as argument (ptr)?
+  // - (ptr-args generally?)
+  if (const Function *Callee = CB->getCalledFunction()) {
+    const Function *Caller = CB->getParent()->getParent();
+    const Module *M = Caller->getParent();
+    std::set<const GlobalVariable *> CalleeGlobals;
+    std::set<const GlobalVariable *> CallerGlobals;
+    for (const GlobalVariable &Global : M->globals())
+      for (const User *U : Global.users())
+        if (const Instruction *User = dyn_cast<Instruction>(U)) {
+          if (User->getParent()->getParent() == Callee)
+            CalleeGlobals.insert(&Global);
+          if (User->getParent()->getParent() == Caller)
+            CallerGlobals.insert(&Global);
+        }
+
+    for (auto *GV : CalleeGlobals)
+      if (CallerGlobals.count(GV)) {
+        unsigned CalleeStores = 0, CalleeLoads = 0;
+        unsigned CallerStores = 0, CallerLoads = 0;
+        countNumMemAccesses(GV, CalleeStores, CalleeLoads, Callee);
+        countNumMemAccesses(GV, CallerStores, CallerLoads, Caller);
+        if ((CalleeStores || CalleeLoads) && (CallerStores || CallerLoads)) {
+          // dbgs() << "GV: @" << GV->getName()
+          //        << " " << *GV->getValueType()
+          //        << "  Callee: " << Callee->getName() << " S: " << CalleeStores
+          //        << " L: " << CalleeLoads << " MEE: " << (CalleeStores + CalleeLoads)
+          //        << " Callee-size: " << Callee->getInstructionCount()
+          //        << "  Caller: " << Caller->getName() << " S: " << CallerStores
+          //        << " L: " << CallerLoads << " MER: " << (CallerStores + CallerLoads)
+          //        << " Uses-around-call: " << usesAroundCall(CB, GV)
+          //        << " Uses-entry-exit-callee: " << usesEntryExit(Callee, GV)
+          //        << "\n";
+
+          // const char *CallerFunName = CB->getParent()->getParent()->getName().data();
+          // const char *CalleeFunName = Callee->getName().data();
+          //            if (std::strcmp(CallerFunName , "S_regmatch") == 0) {
+          // if (std::strcmp(CalleeFunName, "S_regcppop") == 0) {
+          //     return 250;
+          // }
+          // if (std::strcmp(CalleeFunName, "S_regcppush") == 0) {
+          //     return 250;
+          // }
+          if (//usesEntryExit(Callee, GV) >= 5 &&
+              Callee->getInstructionCount() < 250 &&
+
+              //              (CalleeStores >= 5 && CalleeLoads >= 5) && 
+              (CalleeStores + CalleeLoads) > 10 &&
+
+              // CallerLoads > 25)
+              (CallerStores + CallerLoads) > 10)
+            return 500;
+
+          //  if 
+          // if ((CallerStores + CallerLoads) > 25)
+          //                 if (CallerLoads) > 25)
+
+          //}
+        }
+      }
+  }
+
   // Increase the threshold if an incoming argument is used only as a memcpy
   // source.
   if (Function *Callee = CB->getCalledFunction())
     for (Argument &Arg : Callee->args()) {
       bool OtherUse = false;
       if (isUsedAsMemCpySource(&Arg, OtherUse) && !OtherUse)
-        Bonus += 150;
+        Bonus += 1000;
     }
 
+  if (!Bonus) {
+    if (Function *Callee = CB->getCalledFunction()) {
+      unsigned NumStores = 0;
+      unsigned NumLoads = 0;
+      for (unsigned OpIdx = 0; OpIdx != Callee->arg_size(); ++OpIdx) {
+        Value    *CallerArg = CB->getArgOperand(OpIdx);
+        Argument *CalleeArg = Callee->getArg(OpIdx);
+        if (isa<AllocaInst>(CallerArg))
+          countNumMemAccesses(CalleeArg, NumStores, NumLoads);
+      }
+      //      dbgs() << "NUM: " << NumStores << " " << NumLoads << "\n";
+      // Best on povray, but not doing stores slightly better on blender.
+      if (NumLoads > 10)
+        Bonus += NumLoads * 50;
+      if (NumStores > 10)
+        Bonus += NumStores * 50;
+      Bonus = std::min(Bonus, unsigned(1000));
+    }
+  }
+
   LLVM_DEBUG(if (Bonus)
                dbgs() << "++ SZTTI Adding inlining bonus: " << Bonus << "\n";);
+
   return Bonus;
 }
 
diff --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
index e221200cfa08c4..8e5df4ee270020 100644
--- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
+++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
@@ -38,7 +38,7 @@ class SystemZTTIImpl : public BasicTTIImplBase<SystemZTTIImpl> {
   /// \name Scalar TTI Implementations
   /// @{
 
-  unsigned getInliningThresholdMultiplier() const { return 3; }
+  unsigned getInliningThresholdMultiplier() const { return 1; }
   unsigned adjustInliningThreshold(const CallBase *CB) const;
 
   InstructionCost getIntImmCost(const APInt &Imm, Type *Ty,

>From e65dba49d7af5b4aa50c997cd9ae34c1b37ef73d Mon Sep 17 00:00:00 2001
From: Jonas Paulsson <paulson1 at linux.ibm.com>
Date: Fri, 16 Aug 2024 10:02:56 +0200
Subject: [PATCH 2/4] Rewrite

---
 .../SystemZ/SystemZTargetTransformInfo.cpp    | 250 ++++--------------
 .../CodeGen/SystemZ/inline-thresh-adjust.ll   | 139 +++++++++-
 2 files changed, 185 insertions(+), 204 deletions(-)

diff --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
index d9c17cd2b1311a..b98f11d878aebb 100644
--- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
+++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
@@ -54,12 +54,12 @@ static bool isUsedAsMemCpySource(const Value *V, bool &OtherUse) {
 }
 
 static void countNumMemAccesses(const Value *Ptr, unsigned &NumStores,
-                                unsigned &NumLoads, const Function *F = nullptr) {
+                                unsigned &NumLoads, const Function *F) {
   if (!isa<PointerType>(Ptr->getType()))
     return;
   for (const User *U : Ptr->users())
     if (const Instruction *User = dyn_cast<Instruction>(U)) {
-      if (User->getParent()->getParent() == F || !F) {
+      if (User->getParent()->getParent() == F) {
         if (const auto *SI = dyn_cast<StoreInst>(User)) {
           if (SI->getPointerOperand() == Ptr && !SI->isVolatile())
             NumStores++;
@@ -70,220 +70,68 @@ static void countNumMemAccesses(const Value *Ptr, unsigned &NumStores,
         }
         else if (const auto *GEP = dyn_cast<GetElementPtrInst>(User)) {
           if (GEP->getPointerOperand() == Ptr)
-            countNumMemAccesses(GEP, NumStores, NumLoads);
+            countNumMemAccesses(GEP, NumStores, NumLoads, F);
         }
       }
     }
 }
 
-static unsigned usesAroundCall(const CallBase *CB, const GlobalVariable *GV) {
-  unsigned Uses = 0;
-  std::set<const Value *> Ptrs;
-  Ptrs.insert(GV);
-
-  const BasicBlock *BB = CB->getParent();
-  const unsigned CutOff = 20;
-  BasicBlock::const_iterator II = CB->getIterator();
-  for (unsigned N = 0; N < CutOff && II != BB->begin(); N++)
-    II--;
-  BasicBlock::const_iterator EE = CB->getIterator();
-  for (unsigned N = 0; N < CutOff && EE != BB->end(); N++)
-    EE++;
-  
-  for (; II != EE; ++II) {
-    if (const auto *SI = dyn_cast<StoreInst>(II)) {
-      if (Ptrs.count(SI->getPointerOperand()) && !SI->isVolatile())
-        Uses++;
-    }
-    else if (const auto *LI = dyn_cast<LoadInst>(II)) {
-      if (Ptrs.count(LI->getPointerOperand()) && !LI->isVolatile())
-        Uses++;
-    }
-    else if (const auto *GEP = dyn_cast<GetElementPtrInst>(II)) {
-      if (Ptrs.count(GEP->getPointerOperand()))
-        Ptrs.insert(GEP);
-    }
-  }
-  return Uses;
-}
-
-static unsigned usesEntryExit(const Function *F, const GlobalVariable *GV) {
-  unsigned Uses = 0;
-  std::set<const Value *> Ptrs;
-  Ptrs.insert(GV);
-
-  const unsigned CutOff = 100;
-  const BasicBlock *BB = &F->getEntryBlock();
-  unsigned N = 0;
-  for (BasicBlock::const_iterator II = BB->begin();
-       II != BB->end() && N < CutOff; ++II, N++) {
-    if (const auto *SI = dyn_cast<StoreInst>(II)) {
-      if (Ptrs.count(SI->getPointerOperand()) && !SI->isVolatile())
-        Uses++;
-    }
-    else if (const auto *LI = dyn_cast<LoadInst>(II)) {
-      if (Ptrs.count(LI->getPointerOperand()) && !LI->isVolatile())
-        Uses++;
-    }
-    else if (const auto *GEP = dyn_cast<GetElementPtrInst>(II)) {
-      if (Ptrs.count(GEP->getPointerOperand()))
-        Ptrs.insert(GEP);
-    }
-  }
-
-  Ptrs.clear();
-  Ptrs.insert(GV);
-  unsigned ReturnBlockUses = 0;
-  unsigned NumReturnBlocks = 0;
-  for (auto &BBII : *F) {
-    if (isa<ReturnInst>(BBII.getTerminator())) {
-      if (NumReturnBlocks++ > 0) {
-        ReturnBlockUses = 0;
-        break;
-      }
-      BasicBlock::const_iterator EE = BBII.getTerminator()->getIterator();
-      BasicBlock::const_iterator II = EE;
-      for (unsigned N = 0; N < CutOff && II != BBII.begin(); N++)
-        II--;
-      for (; II != EE; ++II) {
-        if (const auto *SI = dyn_cast<StoreInst>(II)) {
-          if (Ptrs.count(SI->getPointerOperand()) && !SI->isVolatile())
-            ReturnBlockUses++;
-        }
-        else if (const auto *LI = dyn_cast<LoadInst>(II)) {
-          if (Ptrs.count(LI->getPointerOperand()) && !LI->isVolatile())
-            ReturnBlockUses++;
-        }
-        else if (const auto *GEP = dyn_cast<GetElementPtrInst>(II)) {
-          if (Ptrs.count(GEP->getPointerOperand()))
-            Ptrs.insert(GEP);
-        }
-      }
-    }
-  }
-
-  return Uses + ReturnBlockUses;
-}
-
 unsigned SystemZTTIImpl::adjustInliningThreshold(const CallBase *CB) const {
   unsigned Bonus = 0;
-
-
-  // dbgs() << "INSTRCOUNT: " << CB->getCalledFunction()->getInstructionCount()
-  //        << CB->getCalledFunction()->getName() << "\n";
-  // if (CB->getCalledFunction()->getInstructionCount() == 216)
-  //   Bonus = 300;
-
-  // if (Function *Callee = CB->getCalledFunction()) {
-  //   const char *CallerFunName = CB->getParent()->getParent()->getName().data();
-  //   const char *CalleeFunName = Callee->getName().data();
-
-  //   if (std::strcmp(CallerFunName , "S_regmatch") == 0) {
-  //     if (std::strcmp(CalleeFunName, "S_reghopmaybe3") == 0 ||  // less important
-  //         std::strcmp(CalleeFunName, "S_regcppop") == 0 ||
-  //         std::strcmp(CalleeFunName, "S_regcppush") == 0)
-  //       return 250;
-  //   }
-  // }
-
-  // Check inlining with memory accesses common to caller and callee
-  // - Around call in caller?  entry/exit blocks in callee?
-  // - Globals used (much?) in both caller and callee
-  // - Specific type of pattern: load; inc/dec; store ?
-  // - non-volatile loads/stores?
-  // - int/fp loads/stores?  ptr?
-  // - num occurences in caller?
-  // - or specifically 2+ functions inlined if many common accesses?
-  // - specifically 2+ functions getting same adress as argument (ptr)?
-  // - (ptr-args generally?)
-  if (const Function *Callee = CB->getCalledFunction()) {
-    const Function *Caller = CB->getParent()->getParent();
-    const Module *M = Caller->getParent();
-    std::set<const GlobalVariable *> CalleeGlobals;
-    std::set<const GlobalVariable *> CallerGlobals;
-    for (const GlobalVariable &Global : M->globals())
-      for (const User *U : Global.users())
-        if (const Instruction *User = dyn_cast<Instruction>(U)) {
-          if (User->getParent()->getParent() == Callee)
-            CalleeGlobals.insert(&Global);
-          if (User->getParent()->getParent() == Caller)
-            CallerGlobals.insert(&Global);
-        }
-
-    for (auto *GV : CalleeGlobals)
-      if (CallerGlobals.count(GV)) {
-        unsigned CalleeStores = 0, CalleeLoads = 0;
-        unsigned CallerStores = 0, CallerLoads = 0;
-        countNumMemAccesses(GV, CalleeStores, CalleeLoads, Callee);
-        countNumMemAccesses(GV, CallerStores, CallerLoads, Caller);
-        if ((CalleeStores || CalleeLoads) && (CallerStores || CallerLoads)) {
-          // dbgs() << "GV: @" << GV->getName()
-          //        << " " << *GV->getValueType()
-          //        << "  Callee: " << Callee->getName() << " S: " << CalleeStores
-          //        << " L: " << CalleeLoads << " MEE: " << (CalleeStores + CalleeLoads)
-          //        << " Callee-size: " << Callee->getInstructionCount()
-          //        << "  Caller: " << Caller->getName() << " S: " << CallerStores
-          //        << " L: " << CallerLoads << " MER: " << (CallerStores + CallerLoads)
-          //        << " Uses-around-call: " << usesAroundCall(CB, GV)
-          //        << " Uses-entry-exit-callee: " << usesEntryExit(Callee, GV)
-          //        << "\n";
-
-          // const char *CallerFunName = CB->getParent()->getParent()->getName().data();
-          // const char *CalleeFunName = Callee->getName().data();
-          //            if (std::strcmp(CallerFunName , "S_regmatch") == 0) {
-          // if (std::strcmp(CalleeFunName, "S_regcppop") == 0) {
-          //     return 250;
-          // }
-          // if (std::strcmp(CalleeFunName, "S_regcppush") == 0) {
-          //     return 250;
-          // }
-          if (//usesEntryExit(Callee, GV) >= 5 &&
-              Callee->getInstructionCount() < 250 &&
-
-              //              (CalleeStores >= 5 && CalleeLoads >= 5) && 
-              (CalleeStores + CalleeLoads) > 10 &&
-
-              // CallerLoads > 25)
-              (CallerStores + CallerLoads) > 10)
-            return 500;
-
-          //  if 
-          // if ((CallerStores + CallerLoads) > 25)
-          //                 if (CallerLoads) > 25)
-
-          //}
-        }
-      }
-  }
+  const Function *Caller = CB->getParent()->getParent();
+  const Function *Callee = CB->getCalledFunction();
+  if (!Callee)
+    return 0;
+  const Module *M = Caller->getParent();
 
   // Increase the threshold if an incoming argument is used only as a memcpy
   // source.
-  if (Function *Callee = CB->getCalledFunction())
-    for (Argument &Arg : Callee->args()) {
-      bool OtherUse = false;
-      if (isUsedAsMemCpySource(&Arg, OtherUse) && !OtherUse)
-        Bonus += 1000;
+  for (const Argument &Arg : Callee->args()) {
+    bool OtherUse = false;
+    if (isUsedAsMemCpySource(&Arg, OtherUse) && !OtherUse) {
+      Bonus = 1000;
+      break;
     }
+  }
 
-  if (!Bonus) {
-    if (Function *Callee = CB->getCalledFunction()) {
-      unsigned NumStores = 0;
-      unsigned NumLoads = 0;
-      for (unsigned OpIdx = 0; OpIdx != Callee->arg_size(); ++OpIdx) {
-        Value    *CallerArg = CB->getArgOperand(OpIdx);
-        Argument *CalleeArg = Callee->getArg(OpIdx);
-        if (isa<AllocaInst>(CallerArg))
-          countNumMemAccesses(CalleeArg, NumStores, NumLoads);
+  // Give bonus for globals used much in both caller and callee.
+  std::set<const GlobalVariable *> CalleeGlobals;
+  std::set<const GlobalVariable *> CallerGlobals;
+  for (const GlobalVariable &Global : M->globals())
+    for (const User *U : Global.users())
+      if (const Instruction *User = dyn_cast<Instruction>(U)) {
+        if (User->getParent()->getParent() == Callee)
+          CalleeGlobals.insert(&Global);
+        if (User->getParent()->getParent() == Caller)
+          CallerGlobals.insert(&Global);
+      }
+  for (auto *GV : CalleeGlobals)
+    if (CallerGlobals.count(GV)) {
+      unsigned CalleeStores = 0, CalleeLoads = 0;
+      unsigned CallerStores = 0, CallerLoads = 0;
+      countNumMemAccesses(GV, CalleeStores, CalleeLoads, Callee);
+      countNumMemAccesses(GV, CallerStores, CallerLoads, Caller);
+      if ((CalleeStores + CalleeLoads) > 10 &&
+          (CallerStores + CallerLoads) > 10) {
+        Bonus = 1000;
+        break;
       }
-      //      dbgs() << "NUM: " << NumStores << " " << NumLoads << "\n";
-      // Best on povray, but not doing stores slightly better on blender.
-      if (NumLoads > 10)
-        Bonus += NumLoads * 50;
-      if (NumStores > 10)
-        Bonus += NumStores * 50;
-      Bonus = std::min(Bonus, unsigned(1000));
     }
+
+  // Give bonus when Callee accesses an Alloca of Caller heavily.
+  unsigned NumStores = 0;
+  unsigned NumLoads = 0;
+  for (unsigned OpIdx = 0; OpIdx != Callee->arg_size(); ++OpIdx) {
+    Value    *CallerArg = CB->getArgOperand(OpIdx);
+    Argument *CalleeArg = Callee->getArg(OpIdx);
+    if (isa<AllocaInst>(CallerArg))
+      countNumMemAccesses(CalleeArg, NumStores, NumLoads, Callee);
   }
+  if (NumLoads > 10)
+    Bonus += NumLoads * 50;
+  if (NumStores > 10)
+    Bonus += NumStores * 50;
+  Bonus = std::min(Bonus, unsigned(1000));
 
   LLVM_DEBUG(if (Bonus)
                dbgs() << "++ SZTTI Adding inlining bonus: " << Bonus << "\n";);
diff --git a/llvm/test/CodeGen/SystemZ/inline-thresh-adjust.ll b/llvm/test/CodeGen/SystemZ/inline-thresh-adjust.ll
index fbcfffa0bb7719..f7c83c7af7021b 100644
--- a/llvm/test/CodeGen/SystemZ/inline-thresh-adjust.ll
+++ b/llvm/test/CodeGen/SystemZ/inline-thresh-adjust.ll
@@ -1,13 +1,13 @@
 ; RUN: opt < %s -mtriple=systemz-unknown -mcpu=z15 -passes='cgscc(inline)' -disable-output \
 ; RUN:   -debug-only=inline,systemztti 2>&1 | FileCheck %s
 ; REQUIRES: asserts
-;
+
 ; Check that the inlining threshold is incremented for a function using an
 ; argument only as a memcpy source.
-
+;
 ; CHECK: Inlining calls in: root_function
 ; CHECK:     Inlining {{.*}} Call:   call void @leaf_function_A(ptr %Dst)
-; CHECK:     ++ SZTTI Adding inlining bonus: 150
+; CHECK:     ++ SZTTI Adding inlining bonus: 1000
 ; CHECK:     Inlining {{.*}} Call:   call void @leaf_function_B(ptr %Dst, ptr %Src)
 
 define void @leaf_function_A(ptr %Dst)  {
@@ -30,3 +30,136 @@ entry:
 }
 
 declare void @llvm.memcpy.p0.p0.i64(ptr noalias nocapture writeonly, ptr noalias nocapture readonly, i64, i1 immarg)
+
+; Check that the inlining threshold is incremented in case of multiple
+; accesses of a global variable by both caller and callee (which is true here
+; after the first call is inlined).
+;
+; CHECK: Inlining calls in: Caller1
+; CHECK: ++ SZTTI Adding inlining bonus: 1000
+
+@GlobV = external global i32
+
+define i64 @Caller1(i1 %cond1, i32 %0) #0 {
+entry:
+  br i1 %cond1, label %sw.bb3437, label %fake_end
+
+common.ret:                                       ; preds = %fake_end, %sw.bb3437
+  ret i64 0
+
+sw.bb3437:                                        ; preds = %entry
+  %call34652 = call i32 @Callee1(ptr null, i32 %0)
+  br label %common.ret
+
+fake_end:                                         ; preds = %entry
+  %call57981 = call i32 @Callee1(ptr null, i32 0)
+  br label %common.ret
+}
+
+define i32 @Callee1(ptr %rex, i32 %parenfloor) #0 {
+entry:
+  %cmp21 = icmp slt i32 %parenfloor, 0
+  br i1 %cmp21, label %for.body, label %for.end
+
+common.ret:                                       ; preds = %for.end, %for.body
+  ret i32 0
+
+for.body:                                         ; preds = %entry
+  %0 = load i32, ptr @GlobV, align 4
+  %inc = or i32 %0, 1
+  store i32 %inc, ptr @GlobV, align 4
+  store i64 0, ptr %rex, align 8
+  %1 = load i32, ptr @GlobV, align 4
+  %inc28 = or i32 %1, 1
+  store i32 %inc28, ptr @GlobV, align 4
+  store i64 0, ptr %rex, align 8
+  %2 = load i32, ptr @GlobV, align 4
+  %inc35 = or i32 %2, 1
+  store i32 %inc35, ptr @GlobV, align 4
+  store i32 0, ptr %rex, align 8
+  br label %common.ret
+
+for.end:                                          ; preds = %entry
+  store i32 0, ptr @GlobV, align 4
+  store i32 0, ptr %rex, align 8
+  %3 = load i32, ptr @GlobV, align 4
+  %inc42 = or i32 %3, 1
+  store i32 %inc42, ptr @GlobV, align 4
+  store i32 0, ptr %rex, align 8
+  %4 = load i32, ptr @GlobV, align 4
+  %inc48 = or i32 %4, 1
+  store i32 %inc48, ptr @GlobV, align 4
+  br label %common.ret
+}
+
+; Check that the inlining threshold is incremented for a function that is
+; accessing an alloca of the caller multiple times.
+;
+; CHECK: Inlining calls in: Caller2
+; CHECK: ++ SZTTI Adding inlining bonus: 550
+
+define i1 @Caller2() {
+entry:
+  %A = alloca [80 x i64], align 8
+  call void @Callee2(ptr %A)
+  ret i1 false
+}
+
+define void @Callee2(ptr nocapture readonly %Arg) {
+entry:
+  %nonzero = getelementptr i8, ptr %Arg, i64 48
+  %0 = load i32, ptr %nonzero, align 8
+  %tobool1.not = icmp eq i32 %0, 0
+  br i1 %tobool1.not, label %if.else38, label %if.then2
+
+if.then2:                                         ; preds = %entry
+  %1 = load i32, ptr %Arg, align 4
+  %tobool4.not = icmp eq i32 %1, 0
+  br i1 %tobool4.not, label %common.ret, label %if.then5
+
+if.then5:                                         ; preds = %if.then2
+  %2 = load double, ptr %Arg, align 8
+  %slab_den = getelementptr i8, ptr %Arg, i64 24
+  %3 = load double, ptr %slab_den, align 8
+  %mul = fmul double %2, %3
+  %cmp = fcmp olt double %mul, 0.000000e+00
+  br i1 %cmp, label %common.ret, label %if.end55
+
+common.ret:                                       ; preds = %if.end100, %if.else79, %if.end55, %if.else38, %if.then5, %if.then2
+  ret void
+
+if.else38:                                        ; preds = %entry
+  %4 = load double, ptr %Arg, align 8
+  %cmp52 = fcmp ogt double %4, 0.000000e+00
+  br i1 %cmp52, label %common.ret, label %if.end55
+
+if.end55:                                         ; preds = %if.else38, %if.then5
+  %arrayidx57 = getelementptr i8, ptr %Arg, i64 52
+  %5 = load i32, ptr %arrayidx57, align 4
+  %tobool58.not = icmp eq i32 %5, 0
+  br i1 %tobool58.not, label %common.ret, label %if.then59
+
+if.then59:                                        ; preds = %if.end55
+  %arrayidx61 = getelementptr i8, ptr %Arg, i64 64
+  %6 = load i32, ptr %arrayidx61, align 4
+  %tobool62.not = icmp eq i32 %6, 0
+  br i1 %tobool62.not, label %if.else79, label %if.end100
+
+if.else79:                                        ; preds = %if.then59
+  %arrayidx84 = getelementptr i8, ptr %Arg, i64 8
+  %7 = load double, ptr %arrayidx84, align 8
+  %arrayidx87 = getelementptr i8, ptr %Arg, i64 32
+  %8 = load double, ptr %arrayidx87, align 8
+  %mul88 = fmul double %7, %8
+  %9 = fcmp olt double %mul88, 0.000000e+00
+  br i1 %9, label %common.ret, label %if.end100
+
+if.end100:                                        ; preds = %if.else79, %if.then59
+  %arrayidx151 = getelementptr i8, ptr %Arg, i64 16
+  %10 = load double, ptr %arrayidx151, align 8
+  %arrayidx154 = getelementptr i8, ptr %Arg, i64 40
+  %11 = load double, ptr %arrayidx154, align 8
+  %mul155 = fmul double %10, %11
+  %cmp181 = fcmp olt double %mul155, 0.000000e+00
+  br label %common.ret
+}

>From 9c72e204ea4e2543914d668247630c4d47008da6 Mon Sep 17 00:00:00 2001
From: Jonas Paulsson <paulson1 at linux.ibm.com>
Date: Wed, 11 Sep 2024 15:22:50 +0200
Subject: [PATCH 3/4] Rebase

---
 llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp | 1 -
 llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h   | 1 -
 2 files changed, 2 deletions(-)

diff --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
index b98f11d878aebb..27159e099d7d6c 100644
--- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
+++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
@@ -135,7 +135,6 @@ unsigned SystemZTTIImpl::adjustInliningThreshold(const CallBase *CB) const {
 
   LLVM_DEBUG(if (Bonus)
                dbgs() << "++ SZTTI Adding inlining bonus: " << Bonus << "\n";);
-
   return Bonus;
 }
 
diff --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
index 8e5df4ee270020..8cc71a6c528f82 100644
--- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
+++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
@@ -38,7 +38,6 @@ class SystemZTTIImpl : public BasicTTIImplBase<SystemZTTIImpl> {
   /// \name Scalar TTI Implementations
   /// @{
 
-  unsigned getInliningThresholdMultiplier() const { return 1; }
   unsigned adjustInliningThreshold(const CallBase *CB) const;
 
   InstructionCost getIntImmCost(const APInt &Imm, Type *Ty,

>From 8294ff89550b04791f2f49923ba9d5d6c93a4a30 Mon Sep 17 00:00:00 2001
From: Jonas Paulsson <paulson1 at linux.ibm.com>
Date: Fri, 20 Sep 2024 16:29:49 +0200
Subject: [PATCH 4/4] Code formatting

---
 llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
index 27159e099d7d6c..7e5728c40950ad 100644
--- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
+++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
@@ -63,12 +63,10 @@ static void countNumMemAccesses(const Value *Ptr, unsigned &NumStores,
         if (const auto *SI = dyn_cast<StoreInst>(User)) {
           if (SI->getPointerOperand() == Ptr && !SI->isVolatile())
             NumStores++;
-        }
-        else if (const auto *LI = dyn_cast<LoadInst>(User)) {
+        } else if (const auto *LI = dyn_cast<LoadInst>(User)) {
           if (LI->getPointerOperand() == Ptr && !LI->isVolatile())
             NumLoads++;
-        }
-        else if (const auto *GEP = dyn_cast<GetElementPtrInst>(User)) {
+        } else if (const auto *GEP = dyn_cast<GetElementPtrInst>(User)) {
           if (GEP->getPointerOperand() == Ptr)
             countNumMemAccesses(GEP, NumStores, NumLoads, F);
         }
@@ -122,7 +120,7 @@ unsigned SystemZTTIImpl::adjustInliningThreshold(const CallBase *CB) const {
   unsigned NumStores = 0;
   unsigned NumLoads = 0;
   for (unsigned OpIdx = 0; OpIdx != Callee->arg_size(); ++OpIdx) {
-    Value    *CallerArg = CB->getArgOperand(OpIdx);
+    Value *CallerArg = CB->getArgOperand(OpIdx);
     Argument *CalleeArg = Callee->getArg(OpIdx);
     if (isa<AllocaInst>(CallerArg))
       countNumMemAccesses(CalleeArg, NumStores, NumLoads, Callee);



More information about the llvm-commits mailing list