[llvm] [memprof] Add extractCallsFromIR (PR #115218)

Kazu Hirata via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 7 12:02:49 PST 2024


https://github.com/kazutakahirata updated https://github.com/llvm/llvm-project/pull/115218

>From ed3c386813975d66a53005cb11b001d98b676ead Mon Sep 17 00:00:00 2001
From: Kazu Hirata <kazu at google.com>
Date: Mon, 22 Jul 2024 16:28:04 -0700
Subject: [PATCH 1/3] [memprof] Add extractCallsFromIR

This patch adds extractCallsFromIR, a function to extract calls from
the IR, which will be used to undrift call site locations in the
MemProf profile.

In a nutshell, the MemProf undrifting works as follows:

- Extract call site locations from the IR.
- Extract call site locations from the MemProf profile.
- Undrift the call site locations with longestCommonSequence.

This patch implements the first bullet point above.  Specifically,
given the IR, the new function returns a map from caller GUIDs to
lists of corresponding call sites.  For example:

Given:

  foo() {
    f1();
    f2(); f3();
  }

extractCallsFromIR returns:

  Caller: foo ->
    {{(Line 1, Column 3), Callee: f1},
     {(Line 2, Column 3), Callee: f2},
     {(Line 2, Column 9), Callee: f3}}

where the line numbers (relative to the beginning of the caller) and
column numbers are sorted in ascending order.  The value side of the
map -- the list of call sites -- can be passed directly to
longestCommonSequence.

To facilitate the review process, I've only implemented basic features
in extractCallsFromIR in this patch.

- The new function extracts direct calls from the LLVM call
  instructions only, skipping indirect calls, intrinsics, and
  instructions without a debug location.  It does not look into the
  inline stack.
- It does not recognize or treat heap allocation functions in any
  special way.

I will address these missing features in subsequent patches.
---
 .../Transforms/Instrumentation/MemProfiler.h  |  35 ++++++
 .../Instrumentation/MemProfiler.cpp           |  47 ++++++++
 .../Transforms/Instrumentation/CMakeLists.txt |   1 +
 .../Instrumentation/MemProfUseTest.cpp        | 108 ++++++++++++++++++
 4 files changed, 191 insertions(+)
 create mode 100644 llvm/unittests/Transforms/Instrumentation/MemProfUseTest.cpp

diff --git a/llvm/include/llvm/Transforms/Instrumentation/MemProfiler.h b/llvm/include/llvm/Transforms/Instrumentation/MemProfiler.h
index f92c6b4775a2a2..076a2785bbaa77 100644
--- a/llvm/include/llvm/Transforms/Instrumentation/MemProfiler.h
+++ b/llvm/include/llvm/Transforms/Instrumentation/MemProfiler.h
@@ -57,6 +57,41 @@ class MemProfUsePass : public PassInfoMixin<MemProfUsePass> {
   IntrusiveRefCntPtr<vfs::FileSystem> FS;
 };
 
+namespace memprof {
+
+struct LineLocation {
+  LineLocation(uint32_t L, uint32_t D) : LineOffset(L), Column(D) {}
+
+  void print(raw_ostream &OS) const;
+  void dump() const;
+
+  bool operator<(const LineLocation &O) const {
+    return LineOffset < O.LineOffset ||
+           (LineOffset == O.LineOffset && Column < O.Column);
+  }
+
+  bool operator==(const LineLocation &O) const {
+    return LineOffset == O.LineOffset && Column == O.Column;
+  }
+
+  bool operator!=(const LineLocation &O) const {
+    return LineOffset != O.LineOffset || Column != O.Column;
+  }
+
+  uint64_t getHashCode() const { return ((uint64_t)Column << 32) | LineOffset; }
+
+  uint32_t LineOffset;
+  uint32_t Column;
+};
+
+// A pair of a call site location and its corresponding callee GUID.
+using CallEdgeTy = std::pair<LineLocation, uint64_t>;
+
+// Extract direct calls with debug locations from the IR.  Return a map from
+// caller GUIDs to sorted call-site lists of {LineLocation, CalleeGUID} pairs.
+DenseMap<uint64_t, SmallVector<CallEdgeTy, 0>> extractCallsFromIR(Module &M);
+
+} // namespace memprof
 } // namespace llvm
 
 #endif
diff --git a/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp b/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp
index 70bee30fd151f6..fef11d9ffe306f 100644
--- a/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp
+++ b/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp
@@ -795,6 +795,53 @@ struct AllocMatchInfo {
   bool Matched = false;
 };
 
+DenseMap<uint64_t, SmallVector<CallEdgeTy, 0>>
+memprof::extractCallsFromIR(Module &M) {
+  DenseMap<uint64_t, SmallVector<CallEdgeTy, 0>> Calls;
+
+  auto GetOffset = [](const DILocation *DIL) {
+    return (DIL->getLine() - DIL->getScope()->getSubprogram()->getLine()) &
+           0xffff;
+  };
+
+  for (Function &F : M) {
+    if (F.isDeclaration())
+      continue;
+
+    for (auto &BB : F) {
+      for (auto &I : BB) {
+        const DILocation *DIL = I.getDebugLoc();
+        if (!DIL)
+          continue;
+
+        if (!isa<CallBase>(&I) || isa<IntrinsicInst>(&I))
+          continue;
+
+        auto *CB = dyn_cast<CallBase>(&I);
+        auto *CalledFunction = CB->getCalledFunction();
+        if (!CalledFunction || CalledFunction->isIntrinsic())
+          continue;
+
+        StringRef CalleeName = CalledFunction->getName();
+        uint64_t CallerGUID =
+            IndexedMemProfRecord::getGUID(DIL->getSubprogramLinkageName());
+        uint64_t CalleeGUID = IndexedMemProfRecord::getGUID(CalleeName);
+        LineLocation Loc = {GetOffset(DIL), DIL->getColumn()};
+        Calls[CallerGUID].emplace_back(Loc, CalleeGUID);
+      }
+    }
+  }
+
+  // Sort each call list by the source location.
+  for (auto &KV : Calls) {
+    auto &Calls = KV.second;
+    llvm::sort(Calls);
+    Calls.erase(llvm::unique(Calls), Calls.end());
+  }
+
+  return Calls;
+}
+
 static void
 readMemprof(Module &M, Function &F, IndexedInstrProfReader *MemProfReader,
             const TargetLibraryInfo &TLI,
diff --git a/llvm/unittests/Transforms/Instrumentation/CMakeLists.txt b/llvm/unittests/Transforms/Instrumentation/CMakeLists.txt
index 1f249b0049d062..80fac2353be416 100644
--- a/llvm/unittests/Transforms/Instrumentation/CMakeLists.txt
+++ b/llvm/unittests/Transforms/Instrumentation/CMakeLists.txt
@@ -8,6 +8,7 @@ set(LLVM_LINK_COMPONENTS
 )
 
 add_llvm_unittest(InstrumentationTests
+  MemProfUseTest.cpp
   PGOInstrumentationTest.cpp
   )
 
diff --git a/llvm/unittests/Transforms/Instrumentation/MemProfUseTest.cpp b/llvm/unittests/Transforms/Instrumentation/MemProfUseTest.cpp
new file mode 100644
index 00000000000000..21c7537852c4df
--- /dev/null
+++ b/llvm/unittests/Transforms/Instrumentation/MemProfUseTest.cpp
@@ -0,0 +1,108 @@
+//===- MemProfUseTest.cpp - MemProf use tests -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/ProfileData/MemProf.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Transforms/Instrumentation/MemProfiler.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace {
+using namespace llvm;
+
+TEST(MemProf, ExtractDirectCallsFromIR) {
+  // The following IR is generated from:
+  //
+  // void f1();
+  // void f2();
+  // void f3();
+  //
+  // void foo() {
+  //   f1();
+  //   f2(); f3();
+  // }
+  StringRef IR = R"IR(
+define dso_local void @_Z3foov() !dbg !10 {
+entry:
+  call void @_Z2f1v(), !dbg !13
+  call void @_Z2f2v(), !dbg !14
+  call void @_Z2f3v(), !dbg !15
+  ret void, !dbg !16
+}
+
+declare !dbg !17 void @_Z2f1v()
+
+declare !dbg !18 void @_Z2f2v()
+
+declare !dbg !19 void @_Z2f3v()
+
+!llvm.dbg.cu = !{!0}
+!llvm.module.flags = !{!2, !3, !4, !5, !6, !7, !8}
+!llvm.ident = !{!9}
+
+!0 = distinct !DICompileUnit(language: DW_LANG_C_plus_plus_14, file: !1, producer: "clang", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly, splitDebugInlining: false, debugInfoForProfiling: true, nameTableKind: None)
+!1 = !DIFile(filename: "foobar.cc", directory: "/")
+!2 = !{i32 7, !"Dwarf Version", i32 5}
+!3 = !{i32 2, !"Debug Info Version", i32 3}
+!4 = !{i32 1, !"wchar_size", i32 4}
+!5 = !{i32 1, !"MemProfProfileFilename", !"memprof.profraw"}
+!6 = !{i32 8, !"PIC Level", i32 2}
+!7 = !{i32 7, !"PIE Level", i32 2}
+!8 = !{i32 7, !"uwtable", i32 2}
+!9 = !{!"clang"}
+!10 = distinct !DISubprogram(name: "foo", linkageName: "_Z3foov", scope: !1, file: !1, line: 5, type: !11, scopeLine: 5, flags: DIFlagPrototyped | DIFlagAllCallsDescribed, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !0)
+!11 = !DISubroutineType(types: !12)
+!12 = !{}
+!13 = !DILocation(line: 6, column: 3, scope: !10)
+!14 = !DILocation(line: 7, column: 3, scope: !10)
+!15 = !DILocation(line: 7, column: 9, scope: !10)
+!16 = !DILocation(line: 8, column: 1, scope: !10)
+!17 = !DISubprogram(name: "f1", linkageName: "_Z2f1v", scope: !1, file: !1, line: 1, type: !11, flags: DIFlagPrototyped, spFlags: DISPFlagOptimized)
+!18 = !DISubprogram(name: "f2", linkageName: "_Z2f2v", scope: !1, file: !1, line: 2, type: !11, flags: DIFlagPrototyped, spFlags: DISPFlagOptimized)
+!19 = !DISubprogram(name: "f3", linkageName: "_Z2f3v", scope: !1, file: !1, line: 3, type: !11, flags: DIFlagPrototyped, spFlags: DISPFlagOptimized)
+)IR";
+
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M = parseAssemblyString(IR, Err, Ctx);
+  ASSERT_TRUE(M);
+
+  auto Calls = memprof::extractCallsFromIR(*M);
+
+  // Expect exactly one caller.
+  ASSERT_THAT(Calls, testing::SizeIs(1));
+
+  auto It = Calls.begin();
+  ASSERT_NE(It, Calls.end());
+
+  const auto &[CallerGUID, CallSites] = *It;
+  EXPECT_EQ(CallerGUID, memprof::IndexedMemProfRecord::getGUID("_Z3foov"));
+  ASSERT_THAT(CallSites, testing::SizeIs(3));
+
+  // Verify that call sites show up in the ascending order of their source
+  // locations.
+  EXPECT_EQ(CallSites[0].first.LineOffset, 1U);
+  EXPECT_EQ(CallSites[0].first.Column, 3U);
+  EXPECT_EQ(CallSites[0].second,
+            memprof::IndexedMemProfRecord::getGUID("_Z2f1v"));
+
+  EXPECT_EQ(CallSites[1].first.LineOffset, 2U);
+  EXPECT_EQ(CallSites[1].first.Column, 3U);
+  EXPECT_EQ(CallSites[1].second,
+            memprof::IndexedMemProfRecord::getGUID("_Z2f2v"));
+
+  EXPECT_EQ(CallSites[2].first.LineOffset, 2U);
+  EXPECT_EQ(CallSites[2].first.Column, 9U);
+  EXPECT_EQ(CallSites[2].second,
+            memprof::IndexedMemProfRecord::getGUID("_Z2f3v"));
+}
+} // namespace

>From 26ff649edd427b1997ec4b23bca6d46a47b426fb Mon Sep 17 00:00:00 2001
From: Kazu Hirata <kazu at google.com>
Date: Wed, 6 Nov 2024 15:03:38 -0800
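Subject: [PATCH 2/3] Address comments.

- Remove the LineLocation::print and LineLocation::dump declarations,
  which this patch series never defines.
- Use structured bindings in the loop that sorts and deduplicates each
  call list.
- Check each call site in the unit test with a single Pair/FieldsAre
  matcher instead of three separate EXPECT_EQs.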
---
 .../Transforms/Instrumentation/MemProfiler.h  |  3 ---
 .../Instrumentation/MemProfiler.cpp           |  7 +++---
 .../Instrumentation/MemProfUseTest.cpp        | 23 ++++++++-----------
 3 files changed, 12 insertions(+), 21 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Instrumentation/MemProfiler.h b/llvm/include/llvm/Transforms/Instrumentation/MemProfiler.h
index 076a2785bbaa77..f168ffc4fdb1ef 100644
--- a/llvm/include/llvm/Transforms/Instrumentation/MemProfiler.h
+++ b/llvm/include/llvm/Transforms/Instrumentation/MemProfiler.h
@@ -62,9 +62,6 @@ namespace memprof {
 struct LineLocation {
   LineLocation(uint32_t L, uint32_t D) : LineOffset(L), Column(D) {}
 
-  void print(raw_ostream &OS) const;
-  void dump() const;
-
   bool operator<(const LineLocation &O) const {
     return LineOffset < O.LineOffset ||
            (LineOffset == O.LineOffset && Column < O.Column);
diff --git a/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp b/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp
index fef11d9ffe306f..ec26b435d9e56f 100644
--- a/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp
+++ b/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp
@@ -833,10 +833,9 @@ memprof::extractCallsFromIR(Module &M) {
   }
 
   // Sort each call list by the source location.
-  for (auto &KV : Calls) {
-    auto &Calls = KV.second;
-    llvm::sort(Calls);
-    Calls.erase(llvm::unique(Calls), Calls.end());
+  for (auto &[CallerGUID, CallList] : Calls) {
+    llvm::sort(CallList);
+    CallList.erase(llvm::unique(CallList), CallList.end());
   }
 
   return Calls;
diff --git a/llvm/unittests/Transforms/Instrumentation/MemProfUseTest.cpp b/llvm/unittests/Transforms/Instrumentation/MemProfUseTest.cpp
index 21c7537852c4df..f2287f3d6b2409 100644
--- a/llvm/unittests/Transforms/Instrumentation/MemProfUseTest.cpp
+++ b/llvm/unittests/Transforms/Instrumentation/MemProfUseTest.cpp
@@ -90,19 +90,14 @@ declare !dbg !19 void @_Z2f3v()
 
   // Verify that call sites show up in the ascending order of their source
   // locations.
-  EXPECT_EQ(CallSites[0].first.LineOffset, 1U);
-  EXPECT_EQ(CallSites[0].first.Column, 3U);
-  EXPECT_EQ(CallSites[0].second,
-            memprof::IndexedMemProfRecord::getGUID("_Z2f1v"));
-
-  EXPECT_EQ(CallSites[1].first.LineOffset, 2U);
-  EXPECT_EQ(CallSites[1].first.Column, 3U);
-  EXPECT_EQ(CallSites[1].second,
-            memprof::IndexedMemProfRecord::getGUID("_Z2f2v"));
-
-  EXPECT_EQ(CallSites[2].first.LineOffset, 2U);
-  EXPECT_EQ(CallSites[2].first.Column, 9U);
-  EXPECT_EQ(CallSites[2].second,
-            memprof::IndexedMemProfRecord::getGUID("_Z2f3v"));
+  EXPECT_THAT(CallSites[0],
+              testing::Pair(testing::FieldsAre(1U, 3U),
+                            memprof::IndexedMemProfRecord::getGUID("_Z2f1v")));
+  EXPECT_THAT(CallSites[1],
+              testing::Pair(testing::FieldsAre(2U, 3U),
+                            memprof::IndexedMemProfRecord::getGUID("_Z2f2v")));
+  EXPECT_THAT(CallSites[2],
+              testing::Pair(testing::FieldsAre(2U, 9U),
+                            memprof::IndexedMemProfRecord::getGUID("_Z2f3v")));
 }
 } // namespace

>From 77d02bb813e9a38fe4e10d62e8c7d37e6aab004d Mon Sep 17 00:00:00 2001
From: Kazu Hirata <kazu at google.com>
Date: Thu, 7 Nov 2024 11:20:40 -0800
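Subject: [PATCH 3/3] Address more comments.

- Add a comment noting that indirect calls and intrinsics are
  disregarded.
- Add using-declarations for llvm::memprof and the gmock matchers in
  the unit test to shorten the expectations.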
---
 .../Instrumentation/MemProfiler.cpp           |  1 +
 .../Instrumentation/MemProfUseTest.cpp        | 21 ++++++++++---------
 2 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp b/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp
index ec26b435d9e56f..0b4d3ff201e622 100644
--- a/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp
+++ b/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp
@@ -819,6 +819,7 @@ memprof::extractCallsFromIR(Module &M) {
 
         auto *CB = dyn_cast<CallBase>(&I);
         auto *CalledFunction = CB->getCalledFunction();
+        // Disregard indirect calls and intrinsics.
         if (!CalledFunction || CalledFunction->isIntrinsic())
           continue;
 
diff --git a/llvm/unittests/Transforms/Instrumentation/MemProfUseTest.cpp b/llvm/unittests/Transforms/Instrumentation/MemProfUseTest.cpp
index f2287f3d6b2409..a510a57099aba4 100644
--- a/llvm/unittests/Transforms/Instrumentation/MemProfUseTest.cpp
+++ b/llvm/unittests/Transforms/Instrumentation/MemProfUseTest.cpp
@@ -18,6 +18,10 @@
 
 namespace {
 using namespace llvm;
+using namespace llvm::memprof;
+using testing::FieldsAre;
+using testing::Pair;
+using testing::SizeIs;
 
 TEST(MemProf, ExtractDirectCallsFromIR) {
   // The following IR is generated from:
@@ -76,28 +80,25 @@ declare !dbg !19 void @_Z2f3v()
   std::unique_ptr<Module> M = parseAssemblyString(IR, Err, Ctx);
   ASSERT_TRUE(M);
 
-  auto Calls = memprof::extractCallsFromIR(*M);
+  auto Calls = extractCallsFromIR(*M);
 
   // Expect exactly one caller.
-  ASSERT_THAT(Calls, testing::SizeIs(1));
+  ASSERT_THAT(Calls, SizeIs(1));
 
   auto It = Calls.begin();
   ASSERT_NE(It, Calls.end());
 
   const auto &[CallerGUID, CallSites] = *It;
-  EXPECT_EQ(CallerGUID, memprof::IndexedMemProfRecord::getGUID("_Z3foov"));
-  ASSERT_THAT(CallSites, testing::SizeIs(3));
+  EXPECT_EQ(CallerGUID, IndexedMemProfRecord::getGUID("_Z3foov"));
+  ASSERT_THAT(CallSites, SizeIs(3));
 
   // Verify that call sites show up in the ascending order of their source
   // locations.
   EXPECT_THAT(CallSites[0],
-              testing::Pair(testing::FieldsAre(1U, 3U),
-                            memprof::IndexedMemProfRecord::getGUID("_Z2f1v")));
+              Pair(FieldsAre(1U, 3U), IndexedMemProfRecord::getGUID("_Z2f1v")));
   EXPECT_THAT(CallSites[1],
-              testing::Pair(testing::FieldsAre(2U, 3U),
-                            memprof::IndexedMemProfRecord::getGUID("_Z2f2v")));
+              Pair(FieldsAre(2U, 3U), IndexedMemProfRecord::getGUID("_Z2f2v")));
   EXPECT_THAT(CallSites[2],
-              testing::Pair(testing::FieldsAre(2U, 9U),
-                            memprof::IndexedMemProfRecord::getGUID("_Z2f3v")));
+              Pair(FieldsAre(2U, 9U), IndexedMemProfRecord::getGUID("_Z2f3v")));
 }
 } // namespace



More information about the llvm-commits mailing list