[llvm] [memprof] Add simplify_type (NFC) (PR #123556)

via llvm-commits llvm-commits at lists.llvm.org
Sun Jan 19 23:53:25 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Kazu Hirata (kazutakahirata)

<details>
<summary>Changes</summary>

IndexCall is a simple wrapper around:

  PointerUnion<CallsiteInfo *, AllocInfo *>

Currently, because there is no CastInfo for IndexCall, we have to
use getBase, like so:

  dyn_cast_if_present<CallsiteInfo *>(Call.getBase())

This patch adds simplify_type<IndexCall>, which in turn enables
CastInfo for IndexCall, so we can drop getBase, like so:

  dyn_cast_if_present<CallsiteInfo *>(Call)


---
Full diff: https://github.com/llvm/llvm-project/pull/123556.diff


1 File Affected:

- (modified) llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp (+27-19) 


``````````diff
diff --git a/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp b/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp
index 61a8f4a448bbd7..988e912b2de838 100644
--- a/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp
+++ b/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp
@@ -821,19 +821,31 @@ struct IndexCall : public PointerUnion<CallsiteInfo *, AllocInfo *> {
 
   IndexCall *operator->() { return this; }
 
-  PointerUnion<CallsiteInfo *, AllocInfo *> getBase() const { return *this; }
-
   void print(raw_ostream &OS) const {
-    if (auto *AI = llvm::dyn_cast_if_present<AllocInfo *>(getBase())) {
+    PointerUnion<CallsiteInfo *, AllocInfo *> Base = *this;
+    if (auto *AI = llvm::dyn_cast_if_present<AllocInfo *>(Base)) {
       OS << *AI;
     } else {
-      auto *CI = llvm::dyn_cast_if_present<CallsiteInfo *>(getBase());
+      auto *CI = llvm::dyn_cast_if_present<CallsiteInfo *>(Base);
       assert(CI);
       OS << *CI;
     }
   }
 };
+} // namespace
+
+namespace llvm {
+template <> struct simplify_type<IndexCall> {
+  using SimpleType = PointerUnion<CallsiteInfo *, AllocInfo *>;
+  static SimpleType getSimplifiedValue(IndexCall &Val) { return Val; }
+};
+template <> struct simplify_type<const IndexCall> {
+  using SimpleType = const PointerUnion<CallsiteInfo *, AllocInfo *>;
+  static SimpleType getSimplifiedValue(const IndexCall &Val) { return Val; }
+};
+} // namespace llvm
 
+namespace {
 /// CRTP derived class for graphs built from summary index (ThinLTO).
 class IndexCallsiteContextGraph
     : public CallsiteContextGraph<IndexCallsiteContextGraph, FunctionSummary,
@@ -1877,9 +1889,9 @@ uint64_t ModuleCallsiteContextGraph::getLastStackId(Instruction *Call) {
 }
 
 uint64_t IndexCallsiteContextGraph::getLastStackId(IndexCall &Call) {
-  assert(isa<CallsiteInfo *>(Call.getBase()));
+  assert(isa<CallsiteInfo *>(Call));
   CallStack<CallsiteInfo, SmallVector<unsigned>::const_iterator>
-      CallsiteContext(dyn_cast_if_present<CallsiteInfo *>(Call.getBase()));
+      CallsiteContext(dyn_cast_if_present<CallsiteInfo *>(Call));
   // Need to convert index into stack id.
   return Index.getStackIdAtIndex(CallsiteContext.back());
 }
@@ -1911,10 +1923,10 @@ std::string IndexCallsiteContextGraph::getLabel(const FunctionSummary *Func,
                                                 unsigned CloneNo) const {
   auto VI = FSToVIMap.find(Func);
   assert(VI != FSToVIMap.end());
-  if (isa<AllocInfo *>(Call.getBase()))
+  if (isa<AllocInfo *>(Call))
     return (VI->second.name() + " -> alloc").str();
   else {
-    auto *Callsite = dyn_cast_if_present<CallsiteInfo *>(Call.getBase());
+    auto *Callsite = dyn_cast_if_present<CallsiteInfo *>(Call);
     return (VI->second.name() + " -> " +
             getMemProfFuncName(Callsite->Callee.name(),
                                Callsite->Clones[CloneNo]))
@@ -1933,9 +1945,9 @@ ModuleCallsiteContextGraph::getStackIdsWithContextNodesForCall(
 
 std::vector<uint64_t>
 IndexCallsiteContextGraph::getStackIdsWithContextNodesForCall(IndexCall &Call) {
-  assert(isa<CallsiteInfo *>(Call.getBase()));
+  assert(isa<CallsiteInfo *>(Call));
   CallStack<CallsiteInfo, SmallVector<unsigned>::const_iterator>
-      CallsiteContext(dyn_cast_if_present<CallsiteInfo *>(Call.getBase()));
+      CallsiteContext(dyn_cast_if_present<CallsiteInfo *>(Call));
   return getStackIdsWithContextNodes<CallsiteInfo,
                                      SmallVector<unsigned>::const_iterator>(
       CallsiteContext);
@@ -2696,8 +2708,7 @@ bool IndexCallsiteContextGraph::findProfiledCalleeThroughTailCalls(
 
 const FunctionSummary *
 IndexCallsiteContextGraph::getCalleeFunc(IndexCall &Call) {
-  ValueInfo Callee =
-      dyn_cast_if_present<CallsiteInfo *>(Call.getBase())->Callee;
+  ValueInfo Callee = dyn_cast_if_present<CallsiteInfo *>(Call)->Callee;
   if (Callee.getSummaryList().empty())
     return nullptr;
   return dyn_cast<FunctionSummary>(Callee.getSummaryList()[0]->getBaseObject());
@@ -2707,8 +2718,7 @@ bool IndexCallsiteContextGraph::calleeMatchesFunc(
     IndexCall &Call, const FunctionSummary *Func,
     const FunctionSummary *CallerFunc,
     std::vector<std::pair<IndexCall, FunctionSummary *>> &FoundCalleeChain) {
-  ValueInfo Callee =
-      dyn_cast_if_present<CallsiteInfo *>(Call.getBase())->Callee;
+  ValueInfo Callee = dyn_cast_if_present<CallsiteInfo *>(Call)->Callee;
   // If there is no summary list then this is a call to an externally defined
   // symbol.
   AliasSummary *Alias =
@@ -2751,10 +2761,8 @@ bool IndexCallsiteContextGraph::calleeMatchesFunc(
 }
 
 bool IndexCallsiteContextGraph::sameCallee(IndexCall &Call1, IndexCall &Call2) {
-  ValueInfo Callee1 =
-      dyn_cast_if_present<CallsiteInfo *>(Call1.getBase())->Callee;
-  ValueInfo Callee2 =
-      dyn_cast_if_present<CallsiteInfo *>(Call2.getBase())->Callee;
+  ValueInfo Callee1 = dyn_cast_if_present<CallsiteInfo *>(Call1)->Callee;
+  ValueInfo Callee2 = dyn_cast_if_present<CallsiteInfo *>(Call2)->Callee;
   return Callee1 == Callee2;
 }
 
@@ -3610,7 +3618,7 @@ IndexCallsiteContextGraph::cloneFunctionForCallsite(
   // Confirm this matches the CloneNo provided by the caller, which is based on
   // the number of function clones we have.
   assert(CloneNo ==
-         (isa<AllocInfo *>(Call.call().getBase())
+         (isa<AllocInfo *>(Call.call())
               ? Call.call().dyn_cast<AllocInfo *>()->Versions.size()
               : Call.call().dyn_cast<CallsiteInfo *>()->Clones.size()));
   // Walk all the instructions in this function. Create a new version for

``````````

</details>


https://github.com/llvm/llvm-project/pull/123556


More information about the llvm-commits mailing list