[llvm] [memprof] Add computeUndriftMap (PR #116478)

Kazu Hirata via llvm-commits llvm-commits at lists.llvm.org
Sat Nov 16 00:42:38 PST 2024


https://github.com/kazutakahirata created https://github.com/llvm/llvm-project/pull/116478

This patch adds computeUndriftMap, a function to compute mappings from
source locations in the MemProf profile to source locations in the IR.


>From 9463fcfc8562a6837022dbecde58f9a5d931ce32 Mon Sep 17 00:00:00 2001
From: Kazu Hirata <kazu at google.com>
Date: Thu, 14 Nov 2024 13:20:29 -0800
Subject: [PATCH] [memprof] Add computeUndriftMap

This patch adds computeUndriftMap, a function to compute mappings from
source locations in the MemProf profile to source locations in the IR.
---
 .../Transforms/Instrumentation/MemProfiler.h  |  16 ++
 .../Instrumentation/MemProfiler.cpp           |  30 +++
 .../Instrumentation/MemProfUseTest.cpp        | 191 ++++++++++++++++++
 3 files changed, 237 insertions(+)

diff --git a/llvm/include/llvm/Transforms/Instrumentation/MemProfiler.h b/llvm/include/llvm/Transforms/Instrumentation/MemProfiler.h
index a197a2687ed029..052b346e527573 100644
--- a/llvm/include/llvm/Transforms/Instrumentation/MemProfiler.h
+++ b/llvm/include/llvm/Transforms/Instrumentation/MemProfiler.h
@@ -18,6 +18,7 @@
 
 namespace llvm {
 class Function;
+class IndexedInstrProfReader;
 class Module;
 class TargetLibraryInfo;
 
@@ -66,6 +67,21 @@ namespace memprof {
 DenseMap<uint64_t, SmallVector<CallEdgeTy, 0>>
 extractCallsFromIR(Module &M, const TargetLibraryInfo &TLI);
 
+struct LineLocationHash {
+  uint64_t operator()(const LineLocation &Loc) const {
+    return Loc.getHashCode();
+  }
+};
+
+using LocToLocMap =
+    std::unordered_map<LineLocation, LineLocation, LineLocationHash>;
+
+// Compute an undrifting map: for each caller GUID found in both the IR and the
+// profile, map call-site locations in the profile to those in the current IR.
+DenseMap<uint64_t, LocToLocMap>
+computeUndriftMap(Module &M, IndexedInstrProfReader *MemProfReader,
+                  const TargetLibraryInfo &TLI);
+
 } // namespace memprof
 } // namespace llvm
 
diff --git a/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp b/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp
index d59e0d26487d4f..c4321b14793882 100644
--- a/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp
+++ b/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp
@@ -42,6 +42,7 @@
 #include "llvm/Support/VirtualFileSystem.h"
 #include "llvm/TargetParser/Triple.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/LongestCommonSequence.h"
 #include "llvm/Transforms/Utils/ModuleUtils.h"
 #include <map>
 #include <set>
@@ -856,6 +857,35 @@ memprof::extractCallsFromIR(Module &M, const TargetLibraryInfo &TLI) {
   return Calls;
 }
 
+DenseMap<uint64_t, LocToLocMap>
+memprof::computeUndriftMap(Module &M, IndexedInstrProfReader *MemProfReader,
+                           const TargetLibraryInfo &TLI) {
+  DenseMap<uint64_t, LocToLocMap> UndriftMaps;
+
+  auto CallsFromProfile = MemProfReader->getMemProfCallerCalleePairs();
+  auto CallsFromIR = extractCallsFromIR(M, TLI);
+
+  // Compute an undrift map for each CallerGUID seen in both IR and profile.
+  for (const auto &[CallerGUID, IRAnchors] : CallsFromIR) {
+    auto It = CallsFromProfile.find(CallerGUID);
+    if (It == CallsFromProfile.end())
+      continue;
+    const auto &ProfileAnchors = It->second;
+
+    LocToLocMap Matchings;
+    longestCommonSequence<LineLocation, GlobalValue::GUID>(
+        ProfileAnchors, IRAnchors, std::equal_to<GlobalValue::GUID>(),
+        [&](LineLocation A, LineLocation B) { Matchings.try_emplace(A, B); });
+
+    // The insertion must succeed because we visit each GUID exactly once.
+    [[maybe_unused]] bool Inserted =
+        UndriftMaps.try_emplace(CallerGUID, std::move(Matchings)).second;
+    assert(Inserted);
+  }
+
+  return UndriftMaps;
+}
+
 static void
 readMemprof(Module &M, Function &F, IndexedInstrProfReader *MemProfReader,
             const TargetLibraryInfo &TLI,
diff --git a/llvm/unittests/Transforms/Instrumentation/MemProfUseTest.cpp b/llvm/unittests/Transforms/Instrumentation/MemProfUseTest.cpp
index cd0e8357a2b2da..944bb4b2a06aff 100644
--- a/llvm/unittests/Transforms/Instrumentation/MemProfUseTest.cpp
+++ b/llvm/unittests/Transforms/Instrumentation/MemProfUseTest.cpp
@@ -11,8 +11,11 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Passes/PassBuilder.h"
+#include "llvm/ProfileData/InstrProfReader.h"
+#include "llvm/ProfileData/InstrProfWriter.h"
 #include "llvm/ProfileData/MemProf.h"
 #include "llvm/Support/SourceMgr.h"
+#include "llvm/Testing/Support/Error.h"
 #include "llvm/Transforms/Instrumentation/MemProfiler.h"
 
 #include "gmock/gmock.h"
@@ -21,9 +24,12 @@
 namespace {
 using namespace llvm;
 using namespace llvm::memprof;
+using testing::Contains;
+using testing::ElementsAre;
 using testing::FieldsAre;
 using testing::Pair;
 using testing::SizeIs;
+using testing::UnorderedElementsAre;
 
 TEST(MemProf, ExtractDirectCallsFromIR) {
   // The following IR is generated from:
@@ -298,4 +304,189 @@ attributes #2 = { builtin allocsize(0) }
   ASSERT_THAT(FooCallSites, SizeIs(1));
   EXPECT_THAT(FooCallSites[0], Pair(FieldsAre(1U, 10U), 0));
 }
+
+// Build a MemInfoBlock with only the getHotColdSchema fields filled in.
+MemInfoBlock makePartialMIB() {
+  MemInfoBlock Block;
+  Block.AllocCount = 1;
+  Block.TotalSize = 5;
+  Block.TotalLifetime = 10;
+  Block.TotalLifetimeAccessDensity = 23;
+  return Block;
+}
+
+IndexedMemProfRecord
+makeRecordV2(std::initializer_list<::llvm::memprof::CallStackId> AllocFrames,
+             std::initializer_list<::llvm::memprof::CallStackId> CallSiteFrames,
+             const MemInfoBlock &Block, const memprof::MemProfSchema &Schema) {
+  IndexedMemProfRecord Record;
+  // IndexedAllocationInfo::CallStack is left empty below because it is used
+  // only in Version1.
+  for (const memprof::CallStackId AllocCSId : AllocFrames)
+    Record.AllocSites.emplace_back(SmallVector<memprof::FrameId>(), AllocCSId,
+                                   Block, Schema);
+  for (const memprof::CallStackId CallSiteCSId : CallSiteFrames)
+    Record.CallSiteIds.push_back(CallSiteCSId);
+  return Record;
+}
+
+static const auto Err = [](Error E) {
+  consumeError(std::move(E));
+  FAIL();
+};
+
+// Make sure that we can undrift direct calls.
+TEST(MemProf, ComputeUndriftingMap) {
+  // Suppose that the source code has changed from:
+  //
+  //   void bar();
+  //   void baz();
+  //   void zzz();
+  //
+  //   void foo() {
+  //     /**/ bar();  // LineLocation(1, 8)
+  //     zzz();       // LineLocation(2, 3)
+  //     baz();       // LineLocation(3, 3)
+  //   }
+  //
+  // to:
+  //
+  //   void bar();
+  //   void baz();
+  //
+  //   void foo() {
+  //     bar();        // LineLocation(1, 3)
+  //     /**/ baz();   // LineLocation(2, 8)
+  //   }
+  //
+  // Notice that the calls to bar and baz have drifted while zzz has been
+  // removed.
+  StringRef IR = R"IR(
+define dso_local void @_Z3foov() #0 !dbg !10 {
+entry:
+  call void @_Z3barv(), !dbg !13
+  call void @_Z3bazv(), !dbg !14
+  ret void, !dbg !15
+}
+
+declare !dbg !16 void @_Z3barv() #1
+
+declare !dbg !17 void @_Z3bazv() #1
+
+attributes #0 = { mustprogress uwtable "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
+attributes #1 = { "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
+
+!llvm.dbg.cu = !{!0}
+!llvm.module.flags = !{!2, !3, !4, !5, !6, !7, !8}
+!llvm.ident = !{!9}
+
+!0 = distinct !DICompileUnit(language: DW_LANG_C_plus_plus_14, file: !1, producer: "clang", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly, splitDebugInlining: false, debugInfoForProfiling: true, nameTableKind: None)
+!1 = !DIFile(filename: "foobar.cc", directory: "/")
+!2 = !{i32 7, !"Dwarf Version", i32 5}
+!3 = !{i32 2, !"Debug Info Version", i32 3}
+!4 = !{i32 1, !"wchar_size", i32 4}
+!5 = !{i32 1, !"MemProfProfileFilename", !"memprof.profraw"}
+!6 = !{i32 8, !"PIC Level", i32 2}
+!7 = !{i32 7, !"PIE Level", i32 2}
+!8 = !{i32 7, !"uwtable", i32 2}
+!9 = !{!"clang"}
+!10 = distinct !DISubprogram(name: "foo", linkageName: "_Z3foov", scope: !1, file: !1, line: 4, type: !11, scopeLine: 4, flags: DIFlagPrototyped | DIFlagAllCallsDescribed, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !0)
+!11 = !DISubroutineType(types: !12)
+!12 = !{}
+!13 = !DILocation(line: 5, column: 3, scope: !10)
+!14 = !DILocation(line: 6, column: 8, scope: !10)
+!15 = !DILocation(line: 7, column: 1, scope: !10)
+!16 = !DISubprogram(name: "bar", linkageName: "_Z3barv", scope: !1, file: !1, line: 1, type: !11, flags: DIFlagPrototyped, spFlags: DISPFlagOptimized)
+!17 = !DISubprogram(name: "baz", linkageName: "_Z3bazv", scope: !1, file: !1, line: 2, type: !11, flags: DIFlagPrototyped, spFlags: DISPFlagOptimized)
+)IR";
+
+  LLVMContext Ctx;
+  SMDiagnostic SMErr;
+  std::unique_ptr<Module> M = parseAssemblyString(IR, SMErr, Ctx);
+  ASSERT_TRUE(M);
+
+  auto *F = M->getFunction("_Z3foov");
+  ASSERT_NE(F, nullptr);
+
+  TargetLibraryInfoWrapperPass WrapperPass;
+  auto &TLI = WrapperPass.getTLI(*F);
+  auto Calls = extractCallsFromIR(*M, TLI);
+
+  uint64_t GUIDFoo = IndexedMemProfRecord::getGUID("_Z3foov");
+  uint64_t GUIDBar = IndexedMemProfRecord::getGUID("_Z3barv");
+  uint64_t GUIDBaz = IndexedMemProfRecord::getGUID("_Z3bazv");
+  uint64_t GUIDZzz = IndexedMemProfRecord::getGUID("_Z3zzzv");
+
+  // Verify that extractCallsFromIR extracts caller-callee pairs as expected.
+  EXPECT_THAT(Calls,
+              UnorderedElementsAre(Pair(
+                  GUIDFoo, ElementsAre(Pair(LineLocation(1, 3), GUIDBar),
+                                       Pair(LineLocation(2, 8), GUIDBaz)))));
+
+  llvm::InstrProfWriter Writer;
+  std::unique_ptr<IndexedInstrProfReader> Reader;
+
+  const MemInfoBlock MIB = makePartialMIB();
+
+  Writer.setMemProfVersionRequested(memprof::Version3);
+  Writer.setMemProfFullSchema(false);
+
+  ASSERT_THAT_ERROR(Writer.mergeProfileKind(InstrProfKind::MemProf),
+                    Succeeded());
+
+  const std::pair<memprof::FrameId, memprof::Frame> Frames[] = {
+      // The call sites within foo.
+      {0, {GUIDFoo, 1, 8, false}},
+      {1, {GUIDFoo, 2, 3, false}},
+      {2, {GUIDFoo, 3, 3, false}},
+      // Line/column numbers below don't matter.
+      {3, {GUIDBar, 9, 9, false}},
+      {4, {GUIDZzz, 9, 9, false}},
+      {5, {GUIDBaz, 9, 9, false}}};
+  for (const auto &[FrameId, Frame] : Frames)
+    Writer.addMemProfFrame(FrameId, Frame, Err);
+
+  const std::pair<memprof::CallStackId, SmallVector<memprof::FrameId>>
+      CallStacks[] = {
+          {0x111, {3, 0}}, // bar called by foo
+          {0x222, {4, 1}}, // zzz called by foo
+          {0x333, {5, 2}}  // baz called by foo
+      };
+  for (const auto &[CSId, CallStack] : CallStacks)
+    Writer.addMemProfCallStack(CSId, CallStack, Err);
+
+  const IndexedMemProfRecord IndexedMR = makeRecordV2(
+      /*AllocFrames=*/{0x111, 0x222, 0x333},
+      /*CallSiteFrames=*/{}, MIB, memprof::getHotColdSchema());
+  Writer.addMemProfRecord(/*Id=*/0x9999, IndexedMR);
+
+  auto Profile = Writer.writeBuffer();
+
+  auto ReaderOrErr =
+      IndexedInstrProfReader::create(std::move(Profile), nullptr);
+  ASSERT_THAT_ERROR(ReaderOrErr.takeError(), Succeeded());
+  Reader = std::move(ReaderOrErr.get());
+
+  // Verify that getMemProfCallerCalleePairs extracts caller-callee pairs as
+  // expected.
+  auto Pairs = Reader->getMemProfCallerCalleePairs();
+  ASSERT_THAT(Pairs, SizeIs(4));
+  ASSERT_THAT(
+      Pairs,
+      Contains(Pair(GUIDFoo, ElementsAre(Pair(LineLocation(1, 8), GUIDBar),
+                                         Pair(LineLocation(2, 3), GUIDZzz),
+                                         Pair(LineLocation(3, 3), GUIDBaz)))));
+
+  // Verify that computeUndriftMap identifies undrifting opportunities:
+  //
+  //   Profile                 IR
+  //   (Line: 1, Column: 8) -> (Line: 1, Column: 3)
+  //   (Line: 3, Column: 3) -> (Line: 2, Column: 8)
+  auto UndriftMap = computeUndriftMap(*M, Reader.get(), TLI);
+  ASSERT_THAT(UndriftMap,
+              UnorderedElementsAre(Pair(
+                  GUIDFoo, UnorderedElementsAre(
+                               Pair(LineLocation(1, 8), LineLocation(1, 3)),
+                               Pair(LineLocation(3, 3), LineLocation(2, 8))))));
+}
 } // namespace



More information about the llvm-commits mailing list