[llvm] 2e426fe - Add unit tests for size returning new funcs in the MemProf use pass. (#105473)

via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 26 09:43:06 PDT 2024


Author: Snehasish Kumar
Date: 2024-08-26T09:43:03-07:00
New Revision: 2e426fe8ff314c2565073e73e27fdbdf36c140a3

URL: https://github.com/llvm/llvm-project/commit/2e426fe8ff314c2565073e73e27fdbdf36c140a3
DIFF: https://github.com/llvm/llvm-project/commit/2e426fe8ff314c2565073e73e27fdbdf36c140a3.diff

LOG: Add unit tests for size returning new funcs in the MemProf use pass. (#105473)

We use a unit test to verify correctness because:
a) we don't have a text-format profile,
b) size-returning new isn't supported natively, and
c) a raw profile would need to be manipulated artificially.

The changes this test covers were made in
https://github.com/llvm/llvm-project/pull/102258.

Added: 
    llvm/unittests/Transforms/Instrumentation/MemProfilerTest.cpp

Modified: 
    llvm/include/llvm/ProfileData/InstrProfReader.h
    llvm/include/llvm/Transforms/Instrumentation/MemProfiler.h
    llvm/lib/Transforms/Instrumentation/MemProfiler.cpp
    llvm/unittests/Transforms/Instrumentation/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ProfileData/InstrProfReader.h b/llvm/include/llvm/ProfileData/InstrProfReader.h
index 3b307d08359980..95c891442fd6e9 100644
--- a/llvm/include/llvm/ProfileData/InstrProfReader.h
+++ b/llvm/include/llvm/ProfileData/InstrProfReader.h
@@ -670,10 +670,11 @@ class IndexedMemProfReader {
 
 public:
   IndexedMemProfReader() = default;
+  virtual ~IndexedMemProfReader() = default;
 
   Error deserialize(const unsigned char *Start, uint64_t MemProfOffset);
 
-  Expected<memprof::MemProfRecord>
+  virtual Expected<memprof::MemProfRecord>
   getMemProfRecord(const uint64_t FuncNameHash) const;
 };
 
@@ -768,11 +769,14 @@ class IndexedInstrProfReader : public InstrProfReader {
                      uint64_t *MismatchedFuncSum = nullptr);
 
   /// Return the memprof record for the function identified by
-  /// llvm::md5(Name).
+  /// llvm::md5(Name). Marked virtual so that unit tests can mock this function.
   Expected<memprof::MemProfRecord> getMemProfRecord(uint64_t FuncNameHash) {
     return MemProfReader.getMemProfRecord(FuncNameHash);
   }
 
+  /// Return the underlying memprof reader.
+  IndexedMemProfReader &getIndexedMemProfReader() { return MemProfReader; }
+
   /// Fill Counts with the profile data for the given function name.
   Error getFunctionCounts(StringRef FuncName, uint64_t FuncHash,
                           std::vector<uint64_t> &Counts);

diff  --git a/llvm/include/llvm/Transforms/Instrumentation/MemProfiler.h b/llvm/include/llvm/Transforms/Instrumentation/MemProfiler.h
index f92c6b4775a2a2..c5d03c98f41581 100644
--- a/llvm/include/llvm/Transforms/Instrumentation/MemProfiler.h
+++ b/llvm/include/llvm/Transforms/Instrumentation/MemProfiler.h
@@ -13,15 +13,15 @@
 #define LLVM_TRANSFORMS_INSTRUMENTATION_MEMPROFILER_H
 
 #include "llvm/ADT/IntrusiveRefCntPtr.h"
+#include "llvm/IR/ModuleSummaryIndex.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/ProfileData/InstrProfReader.h"
+#include "llvm/Support/VirtualFileSystem.h"
 
 namespace llvm {
 class Function;
 class Module;
-
-namespace vfs {
-class FileSystem;
-} // namespace vfs
+class TargetLibraryInfo;
 
 /// Public interface to the memory profiler pass for instrumenting code to
 /// profile memory accesses.
@@ -52,6 +52,17 @@ class MemProfUsePass : public PassInfoMixin<MemProfUsePass> {
                           IntrusiveRefCntPtr<vfs::FileSystem> FS = nullptr);
   PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
 
+  struct AllocMatchInfo {
+    uint64_t TotalSize = 0;
+    AllocationType AllocType = AllocationType::None;
+    bool Matched = false;
+  };
+
+  void
+  readMemprof(Function &F, const IndexedMemProfReader &MemProfReader,
+              const TargetLibraryInfo &TLI,
+              std::map<uint64_t, AllocMatchInfo> &FullStackIdToAllocMatchInfo);
+
 private:
   std::string MemoryProfileFileName;
   IntrusiveRefCntPtr<vfs::FileSystem> FS;

diff  --git a/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp b/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp
index 4a43120c9a9e7f..bd10c037ecf4ad 100644
--- a/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp
+++ b/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp
@@ -39,7 +39,6 @@
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/HashBuilder.h"
-#include "llvm/Support/VirtualFileSystem.h"
 #include "llvm/TargetParser/Triple.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/ModuleUtils.h"
@@ -55,6 +54,7 @@ namespace llvm {
 extern cl::opt<bool> PGOWarnMissing;
 extern cl::opt<bool> NoPGOWarnMismatch;
 extern cl::opt<bool> NoPGOWarnMismatchComdatWeak;
+using AllocMatchInfo = ::llvm::MemProfUsePass::AllocMatchInfo;
 } // namespace llvm
 
 constexpr int LLVM_MEM_PROFILER_VERSION = 1;
@@ -148,10 +148,11 @@ static cl::opt<int> ClDebugMax("memprof-debug-max", cl::desc("Debug max inst"),
 
 // By default disable matching of allocation profiles onto operator new that
 // already explicitly pass a hot/cold hint, since we don't currently
-// override these hints anyway.
-static cl::opt<bool> ClMemProfMatchHotColdNew(
+// override these hints anyway. Not static so that it can be set in the unit
+// test too.
+cl::opt<bool> ClMemProfMatchHotColdNew(
     "memprof-match-hot-cold-new",
- cl::desc(
+    cl::desc(
         "Match allocation profiles onto existing hot/cold operator new calls"),
     cl::Hidden, cl::init(false));
 
@@ -789,17 +790,11 @@ static bool isAllocationWithHotColdVariant(Function *Callee,
   }
 }
 
-struct AllocMatchInfo {
-  uint64_t TotalSize = 0;
-  AllocationType AllocType = AllocationType::None;
-  bool Matched = false;
-};
-
-static void
-readMemprof(Module &M, Function &F, IndexedInstrProfReader *MemProfReader,
-            const TargetLibraryInfo &TLI,
-            std::map<uint64_t, AllocMatchInfo> &FullStackIdToAllocMatchInfo) {
-  auto &Ctx = M.getContext();
+void MemProfUsePass::readMemprof(
+    Function &F, const IndexedMemProfReader &MemProfReader,
+    const TargetLibraryInfo &TLI,
+    std::map<uint64_t, AllocMatchInfo> &FullStackIdToAllocMatchInfo) {
+  auto &Ctx = F.getContext();
   // Previously we used getIRPGOFuncName() here. If F is local linkage,
   // getIRPGOFuncName() returns FuncName with prefix 'FileName;'. But
   // llvm-profdata uses FuncName in dwarf to create GUID which doesn't
@@ -810,7 +805,7 @@ readMemprof(Module &M, Function &F, IndexedInstrProfReader *MemProfReader,
   auto FuncName = F.getName();
   auto FuncGUID = Function::getGUID(FuncName);
   std::optional<memprof::MemProfRecord> MemProfRec;
-  auto Err = MemProfReader->getMemProfRecord(FuncGUID).moveInto(MemProfRec);
+  auto Err = MemProfReader.getMemProfRecord(FuncGUID).moveInto(MemProfRec);
   if (Err) {
     handleAllErrors(std::move(Err), [&](const InstrProfError &IPE) {
       auto Err = IPE.get();
@@ -838,8 +833,8 @@ readMemprof(Module &M, Function &F, IndexedInstrProfReader *MemProfReader,
                          Twine(" Hash = ") + std::to_string(FuncGUID))
                             .str();
 
-      Ctx.diagnose(
-          DiagnosticInfoPGOProfile(M.getName().data(), Msg, DS_Warning));
+      Ctx.diagnose(DiagnosticInfoPGOProfile(F.getParent()->getName().data(),
+                                            Msg, DS_Warning));
     });
     return;
   }
@@ -1036,15 +1031,15 @@ PreservedAnalyses MemProfUsePass::run(Module &M, ModuleAnalysisManager &AM) {
     return PreservedAnalyses::all();
   }
 
-  std::unique_ptr<IndexedInstrProfReader> MemProfReader =
+  std::unique_ptr<IndexedInstrProfReader> IndexedReader =
       std::move(ReaderOrErr.get());
-  if (!MemProfReader) {
+  if (!IndexedReader) {
     Ctx.diagnose(DiagnosticInfoPGOProfile(
-        MemoryProfileFileName.data(), StringRef("Cannot get MemProfReader")));
+        MemoryProfileFileName.data(), StringRef("Cannot get IndexedReader")));
     return PreservedAnalyses::all();
   }
 
-  if (!MemProfReader->hasMemoryProfile()) {
+  if (!IndexedReader->hasMemoryProfile()) {
     Ctx.diagnose(DiagnosticInfoPGOProfile(MemoryProfileFileName.data(),
                                           "Not a memory profile"));
     return PreservedAnalyses::all();
@@ -1057,12 +1052,13 @@ PreservedAnalyses MemProfUsePass::run(Module &M, ModuleAnalysisManager &AM) {
   // it to an allocation in the IR.
   std::map<uint64_t, AllocMatchInfo> FullStackIdToAllocMatchInfo;
 
+  const auto &MemProfReader = IndexedReader->getIndexedMemProfReader();
   for (auto &F : M) {
     if (F.isDeclaration())
       continue;
 
     const TargetLibraryInfo &TLI = FAM.getResult<TargetLibraryAnalysis>(F);
-    readMemprof(M, F, MemProfReader.get(), TLI, FullStackIdToAllocMatchInfo);
+    readMemprof(F, MemProfReader, TLI, FullStackIdToAllocMatchInfo);
   }
 
   if (ClPrintMemProfMatchInfo) {

diff  --git a/llvm/unittests/Transforms/Instrumentation/CMakeLists.txt b/llvm/unittests/Transforms/Instrumentation/CMakeLists.txt
index 1f249b0049d062..1afe1c339e4335 100644
--- a/llvm/unittests/Transforms/Instrumentation/CMakeLists.txt
+++ b/llvm/unittests/Transforms/Instrumentation/CMakeLists.txt
@@ -9,6 +9,7 @@ set(LLVM_LINK_COMPONENTS
 
 add_llvm_unittest(InstrumentationTests
   PGOInstrumentationTest.cpp
+  MemProfilerTest.cpp
   )
 
 target_link_libraries(InstrumentationTests PRIVATE LLVMTestingSupport)

diff  --git a/llvm/unittests/Transforms/Instrumentation/MemProfilerTest.cpp b/llvm/unittests/Transforms/Instrumentation/MemProfilerTest.cpp
new file mode 100644
index 00000000000000..844867d676e8dd
--- /dev/null
+++ b/llvm/unittests/Transforms/Instrumentation/MemProfilerTest.cpp
@@ -0,0 +1,158 @@
+//===- MemProfilerTest.cpp - MemProfiler unit tests ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Instrumentation/MemProfiler.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Attributes.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Passes/PassBuilder.h"
+#include "llvm/ProfileData/InstrProfReader.h"
+#include "llvm/ProfileData/MemProf.h"
+#include "llvm/ProfileData/MemProfData.inc"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/SourceMgr.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+extern llvm::cl::opt<bool> ClMemProfMatchHotColdNew;
+
+namespace llvm {
+namespace memprof {
+namespace {
+
+using ::testing::Return;
+using ::testing::SizeIs;
+
+struct MemProfilerTest : public ::testing::Test {
+  LLVMContext Context;
+  std::unique_ptr<Module> M;
+
+  MemProfilerTest() { ClMemProfMatchHotColdNew = true; }
+
+  void parseAssembly(const StringRef IR) {
+    SMDiagnostic Error;
+    M = parseAssemblyString(IR, Error, Context);
+    std::string ErrMsg;
+    raw_string_ostream OS(ErrMsg);
+    Error.print("", OS);
+
+    // A failure here means that the test itself is buggy.
+    if (!M)
+      report_fatal_error(OS.str().c_str());
+  }
+};
+
+// A mock memprof reader we can inject into the function we are testing.
+class MockMemProfReader : public IndexedMemProfReader {
+public:
+  MOCK_METHOD(Expected<MemProfRecord>, getMemProfRecord,
+              (const uint64_t FuncNameHash), (const, override));
+
+  // A helper function to create mock records from frames.
+  static MemProfRecord makeRecord(ArrayRef<ArrayRef<Frame>> AllocFrames) {
+    MemProfRecord Record;
+    MemInfoBlock Info;
+    // Mimic values which will be below the cold threshold.
+    Info.AllocCount = 1, Info.TotalSize = 550;
+    Info.TotalLifetime = 1000 * 1000, Info.TotalLifetimeAccessDensity = 1;
+    for (const auto &Callstack : AllocFrames) {
+      AllocationInfo AI;
+      AI.Info = PortableMemInfoBlock(Info, getHotColdSchema());
+      AI.CallStack = std::vector(Callstack.begin(), Callstack.end());
+      Record.AllocSites.push_back(AI);
+    }
+    return Record;
+  }
+};
+
+TEST_F(MemProfilerTest, AnnotatesCall) {
+  parseAssembly(R"IR(
+    target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
+    target triple = "x86_64-unknown-linux-gnu"
+
+    define void @_Z3foov() !dbg !10 {
+    entry:
+      %c1 = call {ptr, i64} @__size_returning_new(i64 32), !dbg !13
+      %c2 = call {ptr, i64} @__size_returning_new_aligned(i64 32, i64 8), !dbg !14
+      %c3 = call {ptr, i64} @__size_returning_new_hot_cold(i64 32, i8 254), !dbg !15
+      %c4 = call {ptr, i64} @__size_returning_new_aligned_hot_cold(i64 32, i64 8, i8 254), !dbg !16
+      ret void
+    }
+
+    declare {ptr, i64} @__size_returning_new(i64)
+    declare {ptr, i64} @__size_returning_new_aligned(i64, i64)
+    declare {ptr, i64} @__size_returning_new_hot_cold(i64, i8)
+    declare {ptr, i64} @__size_returning_new_aligned_hot_cold(i64, i64, i8)
+
+    !llvm.dbg.cu = !{!0}
+    !llvm.module.flags = !{!2, !3}
+
+    !0 = distinct !DICompileUnit(language: DW_LANG_C_plus_plus_14, file: !1)
+    !1 = !DIFile(filename: "mock_file.cc", directory: "mock_dir")
+    !2 = !{i32 7, !"Dwarf Version", i32 5}
+    !3 = !{i32 2, !"Debug Info Version", i32 3}
+    !10 = distinct !DISubprogram(name: "foo", linkageName: "_Z3foov", scope: !1, file: !1, line: 4, type: !11, scopeLine: 4, unit: !0, retainedNodes: !12)
+    !11 = !DISubroutineType(types: !12)
+    !12 = !{}
+    !13 = !DILocation(line: 5, column: 10, scope: !10)
+    !14 = !DILocation(line: 6, column: 10, scope: !10)
+    !15 = !DILocation(line: 7, column: 10, scope: !10)
+    !16 = !DILocation(line: 8, column: 10, scope: !10)
+  )IR");
+
+  auto *F = M->getFunction("_Z3foov");
+  ASSERT_NE(F, nullptr);
+
+  TargetLibraryInfoWrapperPass WrapperPass;
+  auto &TLI = WrapperPass.getTLI(*F);
+
+  auto Guid = Function::getGUID("_Z3foov");
+  // All the allocation sites are in foo().
+  MemProfRecord MockRecord =
+      MockMemProfReader::makeRecord({{Frame(Guid, 1, 10, false)},
+                                     {Frame(Guid, 2, 10, false)},
+                                     {Frame(Guid, 3, 10, false)},
+                                     {Frame(Guid, 4, 10, false)}});
+  // Set up mocks for the reader.
+  MockMemProfReader Reader;
+  EXPECT_CALL(Reader, getMemProfRecord(Guid)).WillOnce(Return(MockRecord));
+
+  MemProfUsePass Pass("/unused/profile/path");
+  std::map<uint64_t, MemProfUsePass::AllocMatchInfo> Unused;
+  Pass.readMemprof(*F, Reader, TLI, Unused);
+
+  // Since we only have a single type of behaviour for each allocation site, we
+  // only get function attributes.
+  std::vector<llvm::Attribute> CallsiteAttrs;
+  for (const auto &BB : *F) {
+    for (const auto &I : BB) {
+      if (auto *CI = dyn_cast<CallInst>(&I)) {
+        if (!CI->getCalledFunction()->getName().starts_with(
+                "__size_returning_new"))
+          continue;
+        Attribute Attr = CI->getFnAttr("memprof");
+        // The attribute will be invalid if it didn't find one named memprof.
+        ASSERT_TRUE(Attr.isValid());
+        CallsiteAttrs.push_back(Attr);
+      }
+    }
+  }
+
+  // We match all the variants including ones with the hint since we set
+  // ClMemProfMatchHotColdNew to true.
+  EXPECT_THAT(CallsiteAttrs, SizeIs(4));
+}
+
+} // namespace
+} // namespace memprof
+} // namespace llvm


        


More information about the llvm-commits mailing list