[llvm] [memprof] Add simplify_type (NFC) (PR #123556)

Kazu Hirata via llvm-commits llvm-commits at lists.llvm.org
Sun Jan 19 23:52:51 PST 2025


https://github.com/kazutakahirata created https://github.com/llvm/llvm-project/pull/123556

IndexCall is a simple wrapper around:

  PointerUnion<CallsiteInfo *, AllocInfo *>

Because there is no CastInfo for IndexCall, we currently have to
go through getBase like so:

  dyn_cast_if_present<CallsiteInfo *>(Call.getBase())

This patch adds simplify_type<IndexCall>, which in turn enables
CastInfo for IndexCall, so we can drop getBase like so:

  dyn_cast_if_present<CallsiteInfo *>(Call)


>From 1f5dd34641ad6615d7d43d212da38f18f655e4da Mon Sep 17 00:00:00 2001
From: Kazu Hirata <kazu at google.com>
Date: Sun, 19 Jan 2025 18:44:17 -0800
Subject: [PATCH] [memprof] Add simplify_type (NFC)

IndexCall is a simple wrapper around:

  PointerUnion<CallsiteInfo *, AllocInfo *>

Because there is no CastInfo for IndexCall, we currently have to
go through getBase like so:

  dyn_cast_if_present<CallsiteInfo *>(Call.getBase())

This patch adds simplify_type<IndexCall>, which in turn enables
CastInfo for IndexCall, so we can drop getBase like so:

  dyn_cast_if_present<CallsiteInfo *>(Call)
---
 .../IPO/MemProfContextDisambiguation.cpp      | 46 +++++++++++--------
 1 file changed, 27 insertions(+), 19 deletions(-)

diff --git a/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp b/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp
index 61a8f4a448bbd7..988e912b2de838 100644
--- a/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp
+++ b/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp
@@ -821,19 +821,31 @@ struct IndexCall : public PointerUnion<CallsiteInfo *, AllocInfo *> {
 
   IndexCall *operator->() { return this; }
 
-  PointerUnion<CallsiteInfo *, AllocInfo *> getBase() const { return *this; }
-
   void print(raw_ostream &OS) const {
-    if (auto *AI = llvm::dyn_cast_if_present<AllocInfo *>(getBase())) {
+    PointerUnion<CallsiteInfo *, AllocInfo *> Base = *this;
+    if (auto *AI = llvm::dyn_cast_if_present<AllocInfo *>(Base)) {
       OS << *AI;
     } else {
-      auto *CI = llvm::dyn_cast_if_present<CallsiteInfo *>(getBase());
+      auto *CI = llvm::dyn_cast_if_present<CallsiteInfo *>(Base);
       assert(CI);
       OS << *CI;
     }
   }
 };
+} // namespace
+
+namespace llvm {
+template <> struct simplify_type<IndexCall> {
+  using SimpleType = PointerUnion<CallsiteInfo *, AllocInfo *>;
+  static SimpleType getSimplifiedValue(IndexCall &Val) { return Val; }
+};
+template <> struct simplify_type<const IndexCall> {
+  using SimpleType = const PointerUnion<CallsiteInfo *, AllocInfo *>;
+  static SimpleType getSimplifiedValue(const IndexCall &Val) { return Val; }
+};
+} // namespace llvm
 
+namespace {
 /// CRTP derived class for graphs built from summary index (ThinLTO).
 class IndexCallsiteContextGraph
     : public CallsiteContextGraph<IndexCallsiteContextGraph, FunctionSummary,
@@ -1877,9 +1889,9 @@ uint64_t ModuleCallsiteContextGraph::getLastStackId(Instruction *Call) {
 }
 
 uint64_t IndexCallsiteContextGraph::getLastStackId(IndexCall &Call) {
-  assert(isa<CallsiteInfo *>(Call.getBase()));
+  assert(isa<CallsiteInfo *>(Call));
   CallStack<CallsiteInfo, SmallVector<unsigned>::const_iterator>
-      CallsiteContext(dyn_cast_if_present<CallsiteInfo *>(Call.getBase()));
+      CallsiteContext(dyn_cast_if_present<CallsiteInfo *>(Call));
   // Need to convert index into stack id.
   return Index.getStackIdAtIndex(CallsiteContext.back());
 }
@@ -1911,10 +1923,10 @@ std::string IndexCallsiteContextGraph::getLabel(const FunctionSummary *Func,
                                                 unsigned CloneNo) const {
   auto VI = FSToVIMap.find(Func);
   assert(VI != FSToVIMap.end());
-  if (isa<AllocInfo *>(Call.getBase()))
+  if (isa<AllocInfo *>(Call))
     return (VI->second.name() + " -> alloc").str();
   else {
-    auto *Callsite = dyn_cast_if_present<CallsiteInfo *>(Call.getBase());
+    auto *Callsite = dyn_cast_if_present<CallsiteInfo *>(Call);
     return (VI->second.name() + " -> " +
             getMemProfFuncName(Callsite->Callee.name(),
                                Callsite->Clones[CloneNo]))
@@ -1933,9 +1945,9 @@ ModuleCallsiteContextGraph::getStackIdsWithContextNodesForCall(
 
 std::vector<uint64_t>
 IndexCallsiteContextGraph::getStackIdsWithContextNodesForCall(IndexCall &Call) {
-  assert(isa<CallsiteInfo *>(Call.getBase()));
+  assert(isa<CallsiteInfo *>(Call));
   CallStack<CallsiteInfo, SmallVector<unsigned>::const_iterator>
-      CallsiteContext(dyn_cast_if_present<CallsiteInfo *>(Call.getBase()));
+      CallsiteContext(dyn_cast_if_present<CallsiteInfo *>(Call));
   return getStackIdsWithContextNodes<CallsiteInfo,
                                      SmallVector<unsigned>::const_iterator>(
       CallsiteContext);
@@ -2696,8 +2708,7 @@ bool IndexCallsiteContextGraph::findProfiledCalleeThroughTailCalls(
 
 const FunctionSummary *
 IndexCallsiteContextGraph::getCalleeFunc(IndexCall &Call) {
-  ValueInfo Callee =
-      dyn_cast_if_present<CallsiteInfo *>(Call.getBase())->Callee;
+  ValueInfo Callee = dyn_cast_if_present<CallsiteInfo *>(Call)->Callee;
   if (Callee.getSummaryList().empty())
     return nullptr;
   return dyn_cast<FunctionSummary>(Callee.getSummaryList()[0]->getBaseObject());
@@ -2707,8 +2718,7 @@ bool IndexCallsiteContextGraph::calleeMatchesFunc(
     IndexCall &Call, const FunctionSummary *Func,
     const FunctionSummary *CallerFunc,
     std::vector<std::pair<IndexCall, FunctionSummary *>> &FoundCalleeChain) {
-  ValueInfo Callee =
-      dyn_cast_if_present<CallsiteInfo *>(Call.getBase())->Callee;
+  ValueInfo Callee = dyn_cast_if_present<CallsiteInfo *>(Call)->Callee;
   // If there is no summary list then this is a call to an externally defined
   // symbol.
   AliasSummary *Alias =
@@ -2751,10 +2761,8 @@ bool IndexCallsiteContextGraph::calleeMatchesFunc(
 }
 
 bool IndexCallsiteContextGraph::sameCallee(IndexCall &Call1, IndexCall &Call2) {
-  ValueInfo Callee1 =
-      dyn_cast_if_present<CallsiteInfo *>(Call1.getBase())->Callee;
-  ValueInfo Callee2 =
-      dyn_cast_if_present<CallsiteInfo *>(Call2.getBase())->Callee;
+  ValueInfo Callee1 = dyn_cast_if_present<CallsiteInfo *>(Call1)->Callee;
+  ValueInfo Callee2 = dyn_cast_if_present<CallsiteInfo *>(Call2)->Callee;
   return Callee1 == Callee2;
 }
 
@@ -3610,7 +3618,7 @@ IndexCallsiteContextGraph::cloneFunctionForCallsite(
   // Confirm this matches the CloneNo provided by the caller, which is based on
   // the number of function clones we have.
   assert(CloneNo ==
-         (isa<AllocInfo *>(Call.call().getBase())
+         (isa<AllocInfo *>(Call.call())
               ? Call.call().dyn_cast<AllocInfo *>()->Versions.size()
               : Call.call().dyn_cast<CallsiteInfo *>()->Clones.size()));
   // Walk all the instructions in this function. Create a new version for



More information about the llvm-commits mailing list