[llvm] [Sample Profile] Expand functionality of llvm-profdata function filter (PR #101615)

William Junda Huang via llvm-commits llvm-commits at lists.llvm.org
Sat Aug 10 01:44:29 PDT 2024


https://github.com/huangjd updated https://github.com/llvm/llvm-project/pull/101615

>From ab3010436bed276fa10e863be868aca36338cd46 Mon Sep 17 00:00:00 2001
From: William Huang <williamjhuang at google.com>
Date: Fri, 2 Aug 2024 01:04:38 -0400
Subject: [PATCH 1/2] [Sample Profile] Expand functionality of llvm-profdata
 function filter

llvm-profdata merge can now filter inlined functions or call targets when
the user specifies a fully qualified canonical form. For example
`main:2 @ foo:3.1 @ bar`. See updated doc for syntax. Regex is also
supported.
---
 llvm/docs/CommandGuide/llvm-profdata.rst      |  48 +++++-
 llvm/include/llvm/ProfileData/SampleProf.h    |  26 +++-
 llvm/lib/ProfileData/SampleProf.cpp           | 105 +++++++++++++
 .../Inputs/merge-filter.proftext              |  13 ++
 .../tools/llvm-profdata/merge-filter.test     |  32 +++-
 llvm/tools/llvm-profdata/llvm-profdata.cpp    | 140 ++++++++++++------
 6 files changed, 308 insertions(+), 56 deletions(-)
 create mode 100644 llvm/test/tools/llvm-profdata/Inputs/merge-filter.proftext

diff --git a/llvm/docs/CommandGuide/llvm-profdata.rst b/llvm/docs/CommandGuide/llvm-profdata.rst
index acf016a6dbcd70..af5aeec2df9c62 100644
--- a/llvm/docs/CommandGuide/llvm-profdata.rst
+++ b/llvm/docs/CommandGuide/llvm-profdata.rst
@@ -219,13 +219,16 @@ OPTIONS
 
 .. option:: --function=<string>
 
- Only keep functions matching the regex in the output, all others are erased
- from the profile.
+ Only keep functions matching the filter in the output; all others are erased
+ from the profile. If the filter string is enclosed in double quotes, its
+ contents are treated as a regex; otherwise it is an exact match. For sample
+ profiles, see :ref:`Sample Profile Canonical Function Name <profdata-canonical>`
+ for how to match inlined functions.
 
 .. option:: --no-function=<string>
 
- Remove functions matching the regex from the profile. If both --function and
- --no-function are specified and a function matches both, it is removed.
+ Remove functions matching the filter from the profile. If a function matches
+ both filters specified by --function and --no-function, it is removed.
 
 EXAMPLES
 ^^^^^^^^
@@ -475,3 +478,40 @@ EXIT STATUS
 
 :program:`llvm-profdata` returns 1 if the command is omitted or is invalid,
 if it cannot read input files, or if there is a mismatch between their data.
+
+.. _profdata-canonical:
+
+SAMPLE PROFILE CANONICAL FUNCTION NAME
+--------------------------------------
+
+The canonical name of a (possibly inlined) function in a sample profile is
+defined as follows:
+
+::
+
+    <canonical name> ::= <top level name>
+                     ::= <top level name> <inlined callsites>
+
+    <top level name> ::= <FunctionID> # for context-less profile
+                     ::= "[" (<FunctionID> | <FunctionID> <inlined callsites>) "]" # for CSSPGO profile
+
+    <inlined callsites> ::= <SampleContextFrame>
+                        ::= <SampleContextFrame> <inlined callsites>
+
+    <SampleContextFrame> ::= ":" <LineLocation> " @ " <FunctionID>
+
+    <LineLocation> ::= <number>
+                   ::= <number> "." <number> # with discriminator
+
+    <FunctionID> ::= <number> # for MD5 profile
+                 ::= <string> # for non-MD5 profile
+
+For brevity, we do not check the validity of the mangled name; the only
+restrictions are that it cannot begin with a number or space, and it cannot
+contain the inline context separator " @ ".
+
+The canonical name of a call target is defined as:
+
+::
+
+    <call target canonical name> ::= <canonical name> ":" <LineLocation> " @@ " <FunctionID>
diff --git a/llvm/include/llvm/ProfileData/SampleProf.h b/llvm/include/llvm/ProfileData/SampleProf.h
index e7b154dff06971..7c5ff50ec5206e 100644
--- a/llvm/include/llvm/ProfileData/SampleProf.h
+++ b/llvm/include/llvm/ProfileData/SampleProf.h
@@ -41,6 +41,7 @@ namespace llvm {
 
 class DILocation;
 class raw_ostream;
+class Regex;
 
 const std::error_category &sampleprof_category();
 
@@ -300,6 +301,15 @@ struct LineLocation {
     return ((uint64_t) Discriminator << 32) | LineOffset;
   }
 
+  std::string toString() const {
+    std::string S = std::to_string(LineOffset);
+    if (Discriminator != 0) {
+      S.push_back('.');
+      S += std::to_string(Discriminator);
+    }
+    return S;
+  }
+
   uint32_t LineOffset;
   uint32_t Discriminator;
 };
@@ -757,11 +767,10 @@ class FunctionSamples {
                       : sampleprof_error::success;
   }
 
-  void removeTotalSamples(uint64_t Num) {
-    if (TotalSamples < Num)
-      TotalSamples = 0;
-    else
-      TotalSamples -= Num;
+  uint64_t removeTotalSamples(uint64_t Num) {
+    Num = std::min(Num, TotalSamples);
+    TotalSamples -= Num;
+    return Num;
   }
 
   void setTotalSamples(uint64_t Num) { TotalSamples = Num; }
@@ -1216,6 +1225,13 @@ class FunctionSamples {
   // all the inline instances and names of call targets.
   void findAllNames(DenseSet<FunctionId> &NameSet) const;
 
+  /// Traverse inlined callsites recursively, and erase those with matching
+  /// canonical representation (or do the opposite, if EraseMatch is false).
+  /// Returns total number of samples removed.
+  uint64_t eraseInlinedCallsites(const llvm::Regex &Re,
+                                 std::string &CanonicalName,
+                                 bool MatchCallTargets, bool EraseMatch);
+
   bool operator==(const FunctionSamples &Other) const {
     return (GUIDToFuncNameMap == Other.GUIDToFuncNameMap ||
             (GUIDToFuncNameMap && Other.GUIDToFuncNameMap &&
diff --git a/llvm/lib/ProfileData/SampleProf.cpp b/llvm/lib/ProfileData/SampleProf.cpp
index addb473faebdff..c2cf744ca1ca9e 100644
--- a/llvm/lib/ProfileData/SampleProf.cpp
+++ b/llvm/lib/ProfileData/SampleProf.cpp
@@ -20,6 +20,7 @@
 #include "llvm/Support/Compiler.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/Regex.h"
 #include "llvm/Support/raw_ostream.h"
 #include <string>
 #include <system_error>
@@ -278,6 +279,110 @@ void FunctionSamples::findAllNames(DenseSet<FunctionId> &NameSet) const {
   }
 }
 
+namespace {
+// A class to keep a string invariant where it is appended with contents in a
+// recursive function.
+struct ScopedString {
+  std::string &String;
+
+  size_t Length;
+
+  ScopedString(std::string &String) : String(String), Length(String.length()) {}
+
+  ~ScopedString() {
+    assert(String.length() > Length);
+    String.resize(Length);
+  }
+};
+} // namespace
+
+// CanonicalName is invariant in this function. It should already contain the
+// canonical name of the current FunctionSamples.
+uint64_t FunctionSamples::eraseInlinedCallsites(const llvm::Regex &Re,
+                                                std::string &CanonicalName,
+                                                bool MatchCallTargets,
+                                                bool EraseMatch) {
+  uint64_t Result = 0;
+
+  ScopedString SaveString1(CanonicalName);
+  CanonicalName.push_back(':');
+
+  if (MatchCallTargets) {
+    // Check matching call targets.
+    for (auto BodySampleIt = BodySamples.begin();
+         BodySampleIt != BodySamples.end();) {
+      auto &[Loc, BodySample] = *BodySampleIt;
+      SampleRecord::CallTargetMap &CallTargets =
+          const_cast<SampleRecord::CallTargetMap &>(
+              BodySample.getCallTargets());
+      if (!CallTargets.empty()) {
+        ScopedString SaveString2(CanonicalName);
+        CanonicalName += Loc.toString();
+        CanonicalName += " @@ ";
+        uint64_t RemovedCallTargetCount = 0;
+        for (auto CallTargetIt = CallTargets.begin();
+             CallTargetIt != CallTargets.end();) {
+          ScopedString SaveString3(CanonicalName);
+          CanonicalName += CallTargetIt->first.str();
+
+          if (Re.match(CanonicalName) == EraseMatch) {
+            RemovedCallTargetCount += CallTargetIt->second;
+            CallTargetIt = CallTargets.erase(CallTargetIt);
+          } else
+            ++CallTargetIt;
+        }
+        // Adjust sample count as if they were removed with
+        // removeCalledTargetAndBodySample.
+        Result += BodySample.removeSamples(RemovedCallTargetCount);
+        if (BodySample.getSamples() == 0) {
+          BodySampleIt = BodySamples.erase(BodySampleIt);
+          continue;
+        }
+      }
+      ++BodySampleIt;
+    }
+  }
+
+  // Check matching inlined callsites.
+  for (auto CallsiteSampleIt = CallsiteSamples.begin();
+       CallsiteSampleIt != CallsiteSamples.end();) {
+    auto &[Loc, FSMap] = *CallsiteSampleIt;
+    ScopedString SaveString2(CanonicalName);
+    CanonicalName += Loc.toString();
+    CanonicalName += " @ ";
+    for (auto FunctionSampleIt = FSMap.begin();
+         FunctionSampleIt != FSMap.end();) {
+      FunctionSamples &InlinedFS = FunctionSampleIt->second;
+      ScopedString SaveString3(CanonicalName);
+      CanonicalName.append(InlinedFS.getContext().toString());
+
+      if (Re.match(CanonicalName) == EraseMatch) {
+        Result += InlinedFS.getTotalSamples();
+        FunctionSampleIt = FSMap.erase(FunctionSampleIt);
+      } else {
+        // Recursively process inlined callsites.
+        Result += InlinedFS.eraseInlinedCallsites(Re, CanonicalName,
+                                                  MatchCallTargets, EraseMatch);
+        // If every sample in the inlined callsite is removed, remove the
+        // callsite as well.
+        if (InlinedFS.getTotalSamples() == 0) {
+          FunctionSampleIt = FSMap.erase(FunctionSampleIt);
+        } else
+          ++FunctionSampleIt;
+      }
+    }
+
+    // If FSMap has no more entries, remove it as well.
+    if (FSMap.empty())
+      CallsiteSampleIt = CallsiteSamples.erase(CallsiteSampleIt);
+    else
+      ++CallsiteSampleIt;
+  }
+
+  // Adjust total sample count after removals.
+  return removeTotalSamples(Result);
+}
+
 const FunctionSamples *FunctionSamples::findFunctionSamplesAt(
     const LineLocation &Loc, StringRef CalleeName,
     SampleProfileReaderItaniumRemapper *Remapper,
diff --git a/llvm/test/tools/llvm-profdata/Inputs/merge-filter.proftext b/llvm/test/tools/llvm-profdata/Inputs/merge-filter.proftext
new file mode 100644
index 00000000000000..1f80811832a888
--- /dev/null
+++ b/llvm/test/tools/llvm-profdata/Inputs/merge-filter.proftext
@@ -0,0 +1,13 @@
+main:51312:11
+ 0: 6
+ 1: 6
+ 4: 50000 _Z3fibi:49999 _Z3fibv:1
+ 1: foo:1200
+  11: 1000
+  12: bar:100
+   121: 200
+  13: bat:100
+   131: 100
+ 2.100001: goo:100
+  21: bat:100
+   211: 100
diff --git a/llvm/test/tools/llvm-profdata/merge-filter.test b/llvm/test/tools/llvm-profdata/merge-filter.test
index 5c47c6a75a7c40..7bec3f29f3c59e 100644
--- a/llvm/test/tools/llvm-profdata/merge-filter.test
+++ b/llvm/test/tools/llvm-profdata/merge-filter.test
@@ -1,6 +1,6 @@
 Test llvm-profdata merge with function filters.
 
-RUN: llvm-profdata merge --sample %p/Inputs/sample-profile.proftext --text --function="_Z3.*" | FileCheck %s --check-prefix=CHECK-FILTER1
+RUN: llvm-profdata merge --sample %p/Inputs/sample-profile.proftext --text --function="\"_Z3.*\"" | FileCheck %s --check-prefix=CHECK-FILTER1
 RUN: llvm-profdata merge --sample %p/Inputs/sample-profile.proftext --text --no-function="main" | FileCheck %s --check-prefix=CHECK-FILTER1
 CHECK-FILTER1: _Z3bari:20301:1437
 CHECK-NEXT:  1: 1437
@@ -8,7 +8,7 @@ CHECK-NEXT: _Z3fooi:7711:610
 CHECK-NEXT:  1: 610
 CHECK-NOT: main
 
-RUN: llvm-profdata merge --sample %p/Inputs/sample-profile.proftext --text --function="_Z3.*" --no-function="fooi$" | FileCheck %s --check-prefix=CHECK-FILTER2
+RUN: llvm-profdata merge --sample %p/Inputs/sample-profile.proftext --text --function="\"_Z3.*\"" --no-function="\"fooi$\"" | FileCheck %s --check-prefix=CHECK-FILTER2
 CHECK-FILTER2: _Z3bari:20301:1437
 CHECK-NEXT:  1: 1437
 CHECK-NOT: main
@@ -21,6 +21,31 @@ CHECK-NEXT:  1: 610
 CHECK-NOT: 15822663052811949562
 CHECK-NOT: 3727899762981752933
 
+RUN: llvm-profdata merge --sample %p/Inputs/merge-filter.proftext --text --no-function="\"^main:.+ @ .oo:[0-9]+ @ bat\"" | FileCheck %s --check-prefix=CHECK-FILTER-INLINED
+CHECK-FILTER-INLINED: main:51112:11
+CHECK-NEXT:  0: 6
+CHECK-NEXT:  1: 6
+CHECK-NEXT:  4: 50000 _Z3fibi:49999 _Z3fibv:1
+CHECK-NEXT:   1: foo:1100
+CHECK-NEXT:   11: 1000
+CHECK-NEXT:   12: bar:100
+CHECK-NEXT:    121: 200
+
+RUN: llvm-profdata merge --sample %p/Inputs/merge-filter.proftext --text --no-function="main:4 @@ _Z3fibi" | FileCheck %s --check-prefix=CHECK-FILTER-CALLSITE
+CHECK-FILTER-CALLSITE: main:1313:11
+CHECK-NEXT:  0: 6
+CHECK-NEXT:  1: 6
+CHECK-NEXT:  4: 1 _Z3fibv:1
+CHECK-NEXT:  1: foo:1200
+CHECK-NEXT:   11: 1000
+CHECK-NEXT:   12: bar:100
+CHECK-NEXT:    121: 200
+CHECK-NEXT:   13: bat:100
+CHECK-NEXT:    131: 100
+CHECK-NEXT:  2.100001: goo:100
+CHECK-NEXT:   21: bat:100
+CHECK-NEXT:    211: 100
+
 RUN: llvm-profdata merge --instr %p/Inputs/basic.proftext --text --function="foo" | FileCheck %s --check-prefix=CHECK-FILTER3
 RUN: llvm-profdata merge --instr %p/Inputs/basic.proftext --text --no-function="main" | FileCheck %s --check-prefix=CHECK-FILTER3
 CHECK-FILTER3: foo
@@ -51,7 +76,7 @@ CHECK-NEXT: # Counter Values:
 CHECK-NEXT: 500500
 CHECK-NEXT: 180100
 
-RUN: llvm-profdata merge --sample %p/Inputs/cs-sample.proftext --text --function="main.*@.*_Z5funcBi" | FileCheck %s --check-prefix=CHECK-FILTER5
+RUN: llvm-profdata merge --sample %p/Inputs/cs-sample.proftext --text --function="\"\[main.* @ .*_Z5funcBi.*\]\"" | FileCheck %s --check-prefix=CHECK-FILTER5
 CHECK-FILTER5: [main:3.1 @ _Z5funcBi:1 @ _Z8funcLeafi]:500853:20
 CHECK-NEXT:  0: 15
 CHECK-NEXT:  1: 15
@@ -66,4 +91,3 @@ CHECK-NEXT:  0: 19
 CHECK-NEXT:  1: 19 _Z8funcLeafi:20
 CHECK-NEXT:  3: 12
 CHECK-NEXT:  !Attributes: 1
-
diff --git a/llvm/tools/llvm-profdata/llvm-profdata.cpp b/llvm/tools/llvm-profdata/llvm-profdata.cpp
index 1f6c4c604d57b5..f4b68c7c97379c 100644
--- a/llvm/tools/llvm-profdata/llvm-profdata.cpp
+++ b/llvm/tools/llvm-profdata/llvm-profdata.cpp
@@ -134,7 +134,12 @@ cl::opt<std::string> FuncNameFilter(
     "function",
     cl::desc("Only functions matching the filter are shown in the output. For "
              "overlapping CSSPGO, this takes a function name with calling "
-             "context."),
+             "context. For merge, the filter applies to the text format "
+             "representation of the function, for example, _ZN3fooEv for "
+             "contextless, [_ZN3fooEv:2.1 @ bar:1 @ baz] for CSSPGO, and "
+             "_ZN3fooEv:2.1 @ bar:1 @ baz to filter inlined function. Use "
+             "quoted string for regex match. See detailed documentation of the "
+             "usage in https://llvm.org/docs/CommandGuide/llvm-profdata.html."),
     cl::sub(ShowSubcommand), cl::sub(OverlapSubcommand),
     cl::sub(MergeSubcommand));
 
@@ -824,59 +829,108 @@ static void mergeWriterContexts(WriterContext *Dst, WriterContext *Src) {
   });
 }
 
-static StringRef
-getFuncName(const StringMap<InstrProfWriter::ProfilingData>::value_type &Val) {
-  return Val.first();
+// Limitation: Wildcard may cause unexpected regex match, for example,
+// "foo.*bar:1 @ baz" may match "foo:1 @ bar:1 @ baz". The user should specify
+// regex pattern in a way not to match strings that are not valid mangled names.
+static void filterFunctions(SampleProfileMap &Profiles,
+                            std::string FilterString, bool EraseMatch) {
+  // Checking all call targets is very slow, only do this if FilterString can
+  // ever match a call target.
+  bool MatchCallTargets = (FilterString.find(" @@ ") != std::string::npos);
+
+  // Search inlined callsites recursively is extremely slow, only do this if
+  // FilterString has more than one part delimited by " @ " (except for CSSPGO
+  // top level function, we will check for that later) or " @@ ".
+  bool SearchInlinedCallsites =
+      MatchCallTargets || (FilterString.find(" @ ") != std::string::npos);
+
+  uint64_t MD5 = 0;
+
+  // If Pattern is quoted string, treat it as escaped regex, otherwise treat it
+  // as literal match.
+  if (FilterString[0] == '\"') {
+    if (FilterString.size() < 2 || FilterString.back() != '\"')
+      exitWithError("missing terminating '\"' character");
+    FilterString = FilterString.substr(1, FilterString.length() - 2);
+
+    // If pattern is "\[.*\]", it is CSSPGO top level function.
+    if (FilterString[0] == '\\' && FilterString[1] == '[' &&
+        FilterString[FilterString.size() - 2] == '\\' &&
+        FilterString[FilterString.size() - 1] == ']')
+      SearchInlinedCallsites = false;
+  } else {
+    // If pattern is "[.*]", it is CSSPGO top level function.
+    if (FilterString[0] == '[' && FilterString[FilterString.size() - 1] == ']')
+      SearchInlinedCallsites = false;
+
+    // Handle MD5 profile as well if possible. Obviously it only makes sense if
+    // FilterString only matches top level function and is plain text only.
+    if (!SearchInlinedCallsites &&
+        !std::all_of(FilterString.begin(), FilterString.end(), ::isdigit)) {
+      std::list<SampleContextFrameVector> CSNameTable;
+      MD5 = SampleContext(FilterString, CSNameTable).getHashCode();
+    }
+
+    // Mangled name can contain `?` (MSVC), `.` (LLVM suffix), or `[]` (CSSPGO).
+    FilterString = "^" + llvm::Regex::escape(FilterString) + "$";
+  }
+
+  llvm::Regex Re(FilterString);
+  if (std::string Error; !Re.isValid(Error))
+    exitWithError(Error);
+
+  for (auto FS = Profiles.begin(); FS != Profiles.end();) {
+    std::string CanonicalName = FS->second.getContext().toString();
+    if (FS->second.getContext().hasContext())
+      CanonicalName = "[" + CanonicalName + "]";
+    if ((Re.match(CanonicalName) ||
+         (FunctionSamples::UseMD5 &&
+          FS->second.getContext().getHashCode() == MD5)) == EraseMatch) {
+      FS = Profiles.erase(FS);
+      continue;
+    }
+    if (SearchInlinedCallsites)
+      FS->second.eraseInlinedCallsites(Re, CanonicalName, MatchCallTargets,
+                                       EraseMatch);
+    FS++;
+  }
 }
 
-static std::string
-getFuncName(const SampleProfileMap::value_type &Val) {
-  return Val.second.getContext().toString();
+static void filterFunctions(StringMap<InstrProfWriter::ProfilingData> &Profiles,
+                            std::string FilterString, bool EraseMatch) {
+  // If Pattern is quoted string, treat it as escaped regex, otherwise treat it
+  // as literal match.
+  if (FilterString[0] == '\"') {
+    if (FilterString.size() < 2 || FilterString.back() != '\"')
+      exitWithError("missing terminating '\"' character");
+    FilterString = FilterString.substr(1, FilterString.length() - 2);
+  }
+
+  llvm::Regex Re(FilterString);
+  if (std::string Error; !Re.isValid(Error))
+    exitWithError(Error);
+
+  for (auto ProfileIt = Profiles.begin(); ProfileIt != Profiles.end();) {
+    auto Tmp = ProfileIt++;
+    if (Re.match(Tmp->first()) == EraseMatch)
+      Profiles.erase(Tmp);
+  }
 }
 
-template <typename T>
-static void filterFunctions(T &ProfileMap) {
+template <typename T> static void filterFunctions(T &Profiles) {
   bool hasFilter = !FuncNameFilter.empty();
   bool hasNegativeFilter = !FuncNameNegativeFilter.empty();
   if (!hasFilter && !hasNegativeFilter)
     return;
 
-  // If filter starts with '?' it is MSVC mangled name, not a regex.
-  llvm::Regex ProbablyMSVCMangledName("[?@$_0-9A-Za-z]+");
-  if (hasFilter && FuncNameFilter[0] == '?' &&
-      ProbablyMSVCMangledName.match(FuncNameFilter))
-    FuncNameFilter = llvm::Regex::escape(FuncNameFilter);
-  if (hasNegativeFilter && FuncNameNegativeFilter[0] == '?' &&
-      ProbablyMSVCMangledName.match(FuncNameNegativeFilter))
-    FuncNameNegativeFilter = llvm::Regex::escape(FuncNameNegativeFilter);
-
-  size_t Count = ProfileMap.size();
-  llvm::Regex Pattern(FuncNameFilter);
-  llvm::Regex NegativePattern(FuncNameNegativeFilter);
-  std::string Error;
-  if (hasFilter && !Pattern.isValid(Error))
-    exitWithError(Error);
-  if (hasNegativeFilter && !NegativePattern.isValid(Error))
-    exitWithError(Error);
+  size_t Count = Profiles.size();
 
-  // Handle MD5 profile, so it is still able to match using the original name.
-  std::string MD5Name = std::to_string(llvm::MD5Hash(FuncNameFilter));
-  std::string NegativeMD5Name =
-      std::to_string(llvm::MD5Hash(FuncNameNegativeFilter));
-
-  for (auto I = ProfileMap.begin(); I != ProfileMap.end();) {
-    auto Tmp = I++;
-    const auto &FuncName = getFuncName(*Tmp);
-    // Negative filter has higher precedence than positive filter.
-    if ((hasNegativeFilter &&
-         (NegativePattern.match(FuncName) ||
-          (FunctionSamples::UseMD5 && NegativeMD5Name == FuncName))) ||
-        (hasFilter && !(Pattern.match(FuncName) ||
-                        (FunctionSamples::UseMD5 && MD5Name == FuncName))))
-      ProfileMap.erase(Tmp);
-  }
+  if (!FuncNameFilter.empty())
+    filterFunctions(Profiles, FuncNameFilter, false);
+  if (!FuncNameNegativeFilter.empty())
+    filterFunctions(Profiles, FuncNameNegativeFilter, true);
 
-  llvm::dbgs() << Count - ProfileMap.size() << " of " << Count << " functions "
+  llvm::dbgs() << Count - Profiles.size() << " of " << Count << " functions "
                << "in the original profile are filtered.\n";
 }
 

>From 18318cecd79f0c7ed6618c59dae3a85438253c16 Mon Sep 17 00:00:00 2001
From: William Huang <williamjhuang at google.com>
Date: Sat, 10 Aug 2024 04:04:55 -0400
Subject: [PATCH 2/2] Fixed API design, and used more precise function name
 description Fixed test cases Added more examples to documentation

---
 llvm/docs/CommandGuide/llvm-profdata.rst      |  17 +++
 llvm/include/llvm/ProfileData/SampleProf.h    |  11 +-
 llvm/lib/ProfileData/SampleProf.cpp           | 113 ++++++++----------
 .../tools/llvm-profdata/merge-filter.test     |  32 ++---
 llvm/tools/llvm-profdata/llvm-profdata.cpp    |  26 ++--
 5 files changed, 102 insertions(+), 97 deletions(-)

diff --git a/llvm/docs/CommandGuide/llvm-profdata.rst b/llvm/docs/CommandGuide/llvm-profdata.rst
index af5aeec2df9c62..03e37e9b472148 100644
--- a/llvm/docs/CommandGuide/llvm-profdata.rst
+++ b/llvm/docs/CommandGuide/llvm-profdata.rst
@@ -515,3 +515,20 @@ The canonical name of a call taget is defined as:
 ::
 
     <call target canonical name> ::= <canonical name> ":" <LineLocation> " @@ " <FunctionID>
+
+See the following text sample profiles, in which all canonical names are annotated.
+
+::
+
+    # Contextless profile
+    main:200:0                      # main
+     1: 100 _Z3bari:100             # main:1 @@ _Z3bari
+     10: inline1:100                # main:10 @ inline1
+      11: 1
+      12.3: inline2:99              # main:10 @ inline1:12.3 @ inline2
+       111: 98
+       111.2: 1 _Z3fooi:1           # main:10 @ inline1:12.3 @ inline2:111.2 @@ _Z3fooi
+
+    # CSSPGO
+    [main:1 @ foo:2.3 @ bar]:100:0  # [main:1 @ foo:2.3 @ bar]
+     1.2: 100 baz:100               # [main:1 @ foo:2.3 @ bar]:1.2 @@ baz
diff --git a/llvm/include/llvm/ProfileData/SampleProf.h b/llvm/include/llvm/ProfileData/SampleProf.h
index 7c5ff50ec5206e..bc6ff3100fc456 100644
--- a/llvm/include/llvm/ProfileData/SampleProf.h
+++ b/llvm/include/llvm/ProfileData/SampleProf.h
@@ -1225,12 +1225,13 @@ class FunctionSamples {
   // all the inline instances and names of call targets.
   void findAllNames(DenseSet<FunctionId> &NameSet) const;
 
-  /// Traverse inlined callsites recursively, and erase those with matching
-  /// canonical representation (or do the opposite, if EraseMatch is false).
+  /// Traverse inlined callsites recursively, and erase call targets in
+  /// BodySamples and callsites in CallsiteSamples with matching canonical
+  /// representation (or do the opposite, if Inverse is true).
   /// Returns total number of samples removed.
-  uint64_t eraseInlinedCallsites(const llvm::Regex &Re,
-                                 std::string &CanonicalName,
-                                 bool MatchCallTargets, bool EraseMatch);
+  uint64_t removeCallTargetsAndCallsites(const llvm::Regex &Re,
+                                         std::string &CanonicalName,
+                                         bool Inverse = false);
 
   bool operator==(const FunctionSamples &Other) const {
     return (GUIDToFuncNameMap == Other.GUIDToFuncNameMap ||
diff --git a/llvm/lib/ProfileData/SampleProf.cpp b/llvm/lib/ProfileData/SampleProf.cpp
index c2cf744ca1ca9e..ab3207539efe3a 100644
--- a/llvm/lib/ProfileData/SampleProf.cpp
+++ b/llvm/lib/ProfileData/SampleProf.cpp
@@ -279,96 +279,80 @@ void FunctionSamples::findAllNames(DenseSet<FunctionId> &NameSet) const {
   }
 }
 
-namespace {
-// A class to keep a string invariant where it is appended with contents in a
-// recursive function.
-struct ScopedString {
-  std::string &String;
-
-  size_t Length;
-
-  ScopedString(std::string &String) : String(String), Length(String.length()) {}
-
-  ~ScopedString() {
-    assert(String.length() > Length);
-    String.resize(Length);
-  }
-};
-} // namespace
-
 // CanonicalName is invariant in this function. It should already contain the
 // canonical name of the current FunctionSamples.
-uint64_t FunctionSamples::eraseInlinedCallsites(const llvm::Regex &Re,
-                                                std::string &CanonicalName,
-                                                bool MatchCallTargets,
-                                                bool EraseMatch) {
+uint64_t FunctionSamples::removeCallTargetsAndCallsites(
+    const llvm::Regex &Re, std::string &CanonicalName, bool Inverse) {
   uint64_t Result = 0;
-
-  ScopedString SaveString1(CanonicalName);
   CanonicalName.push_back(':');
+  size_t Length1 = CanonicalName.size();
+
+  // Match call targets.
+  for (auto BodySampleIt = BodySamples.begin();
+       BodySampleIt != BodySamples.end();) {
+    auto &[Loc, BodySample] = *BodySampleIt;
+    SampleRecord::CallTargetMap &CallTargets =
+        const_cast<SampleRecord::CallTargetMap &>(BodySample.getCallTargets());
+    if (!CallTargets.empty()) {
+      CanonicalName += Loc.toString();
+      CanonicalName += " @@ ";
+      size_t Length2 = CanonicalName.size();
+
+      uint64_t RemovedCallTargetCount = 0;
+      for (auto CallTargetIt = CallTargets.begin();
+           CallTargetIt != CallTargets.end();) {
+        CanonicalName += CallTargetIt->first.str();
+
+        if (Re.match(CanonicalName) != Inverse) {
+          RemovedCallTargetCount += CallTargetIt->second;
+          CallTargetIt = CallTargets.erase(CallTargetIt);
+        } else
+          ++CallTargetIt;
 
-  if (MatchCallTargets) {
-    // Check matching call targets.
-    for (auto BodySampleIt = BodySamples.begin();
-         BodySampleIt != BodySamples.end();) {
-      auto &[Loc, BodySample] = *BodySampleIt;
-      SampleRecord::CallTargetMap &CallTargets =
-          const_cast<SampleRecord::CallTargetMap &>(
-              BodySample.getCallTargets());
-      if (!CallTargets.empty()) {
-        ScopedString SaveString2(CanonicalName);
-        CanonicalName += Loc.toString();
-        CanonicalName += " @@ ";
-        uint64_t RemovedCallTargetCount = 0;
-        for (auto CallTargetIt = CallTargets.begin();
-             CallTargetIt != CallTargets.end();) {
-          ScopedString SaveString3(CanonicalName);
-          CanonicalName += CallTargetIt->first.str();
-
-          if (Re.match(CanonicalName) == EraseMatch) {
-            RemovedCallTargetCount += CallTargetIt->second;
-            CallTargetIt = CallTargets.erase(CallTargetIt);
-          } else
-            ++CallTargetIt;
-        }
-        // Adjust sample count as if they were removed with
-        // removeCalledTargetAndBodySample.
-        Result += BodySample.removeSamples(RemovedCallTargetCount);
-        if (BodySample.getSamples() == 0) {
-          BodySampleIt = BodySamples.erase(BodySampleIt);
-          continue;
-        }
+        CanonicalName.resize(Length2);
+      }
+      CanonicalName.resize(Length1);
+
+      // Adjust sample count as if they were removed with
+      // removeCalledTargetAndBodySample.
+      Result += BodySample.removeSamples(RemovedCallTargetCount);
+      if (BodySample.getSamples() == 0) {
+        BodySampleIt = BodySamples.erase(BodySampleIt);
+        continue;
       }
-      ++BodySampleIt;
     }
+    ++BodySampleIt;
   }
 
-  // Check matching inlined callsites.
+  assert(CanonicalName.size() == Length1);
+
+  // Match inlined callsites; CanonicalName is restored after every callee.
   for (auto CallsiteSampleIt = CallsiteSamples.begin();
        CallsiteSampleIt != CallsiteSamples.end();) {
     auto &[Loc, FSMap] = *CallsiteSampleIt;
-    ScopedString SaveString2(CanonicalName);
     CanonicalName += Loc.toString();
     CanonicalName += " @ ";
+    size_t Length2 = CanonicalName.size();
+
     for (auto FunctionSampleIt = FSMap.begin();
          FunctionSampleIt != FSMap.end();) {
       FunctionSamples &InlinedFS = FunctionSampleIt->second;
-      ScopedString SaveString3(CanonicalName);
       CanonicalName.append(InlinedFS.getContext().toString());

-      if (Re.match(CanonicalName) == EraseMatch) {
+      if (Re.match(CanonicalName) != Inverse) {
         Result += InlinedFS.getTotalSamples();
         FunctionSampleIt = FSMap.erase(FunctionSampleIt);
       } else {
         // Recursively process inlined callsites.
-        Result += InlinedFS.eraseInlinedCallsites(Re, CanonicalName,
-                                                  MatchCallTargets, EraseMatch);
+        Result +=
+            InlinedFS.removeCallTargetsAndCallsites(Re, CanonicalName, Inverse);
         // If every sample in the inlined callsite is removed, remove the
         // callsite as well.
-        if (InlinedFS.getTotalSamples() == 0) {
+        if (InlinedFS.getTotalSamples() == 0)
           FunctionSampleIt = FSMap.erase(FunctionSampleIt);
-        } else
+        else
           ++FunctionSampleIt;
       }
+      CanonicalName.resize(Length2);
     }

@@ -377,8 +361,11 @@ uint64_t FunctionSamples::eraseInlinedCallsites(const llvm::Regex &Re,
       CallsiteSampleIt = CallsiteSamples.erase(CallsiteSampleIt);
     else
       ++CallsiteSampleIt;
+    CanonicalName.resize(Length1);
   }

+  assert(CanonicalName.size() == Length1);
+  CanonicalName.pop_back();
   // Adjust total sample count after removals.
   return removeTotalSamples(Result);
 }
diff --git a/llvm/test/tools/llvm-profdata/merge-filter.test b/llvm/test/tools/llvm-profdata/merge-filter.test
index 7bec3f29f3c59e..0643890ae0ca16 100644
--- a/llvm/test/tools/llvm-profdata/merge-filter.test
+++ b/llvm/test/tools/llvm-profdata/merge-filter.test
@@ -31,21 +31,6 @@ CHECK-NEXT:   11: 1000
 CHECK-NEXT:   12: bar:100
 CHECK-NEXT:    121: 200
 
-RUN: llvm-profdata merge --sample %p/Inputs/merge-filter.proftext --text --no-function="main:4 @@ _Z3fibi" | FileCheck %s --check-prefix=CHECK-FILTER-CALLSITE
-CHECK-FILTER-CALLSITE: main:1313:11
-CHECK-NEXT:  0: 6
-CHECK-NEXT:  1: 6
-CHECK-NEXT:  4: 1 _Z3fibv:1
-CHECK-NEXT:  1: foo:1200
-CHECK-NEXT:   11: 1000
-CHECK-NEXT:   12: bar:100
-CHECK-NEXT:    121: 200
-CHECK-NEXT:   13: bat:100
-CHECK-NEXT:    131: 100
-CHECK-NEXT:  2.100001: goo:100
-CHECK-NEXT:   21: bat:100
-CHECK-NEXT:    211: 100
-
 RUN: llvm-profdata merge --instr %p/Inputs/basic.proftext --text --function="foo" | FileCheck %s --check-prefix=CHECK-FILTER3
 RUN: llvm-profdata merge --instr %p/Inputs/basic.proftext --text --no-function="main" | FileCheck %s --check-prefix=CHECK-FILTER3
 CHECK-FILTER3: foo
@@ -66,7 +51,7 @@ CHECK-NEXT: # Counter Values:
 CHECK-NEXT: 500500
 CHECK-NEXT: 180100
 
-RUN: llvm-profdata merge --instr %p/Inputs/basic.proftext --text --function="foo" --no-function="^foo$" | FileCheck %s --check-prefix=CHECK-FILTER4
+RUN: llvm-profdata merge --instr %p/Inputs/basic.proftext --text --function="\"^foo\"" --no-function="\"^foo$\"" | FileCheck %s --check-prefix=CHECK-FILTER4
 CHECK-FILTER4: foo2
 CHECK-NEXT: # Func Hash:
 CHECK-NEXT: 10
@@ -91,3 +76,18 @@ CHECK-NEXT:  0: 19
 CHECK-NEXT:  1: 19 _Z8funcLeafi:20
 CHECK-NEXT:  3: 12
 CHECK-NEXT:  !Attributes: 1
+
+RUN: llvm-profdata merge --sample %p/Inputs/merge-filter.proftext --text --no-function="main:4 @@ _Z3fibi" | FileCheck %s --check-prefix=CHECK-FILTER-CALLSITE
+CHECK-FILTER-CALLSITE: main:1313:11
+CHECK-NEXT:  0: 6
+CHECK-NEXT:  1: 6
+CHECK-NEXT:  4: 1 _Z3fibv:1
+CHECK-NEXT:  1: foo:1200
+CHECK-NEXT:   11: 1000
+CHECK-NEXT:   12: bar:100
+CHECK-NEXT:    121: 200
+CHECK-NEXT:   13: bat:100
+CHECK-NEXT:    131: 100
+CHECK-NEXT:  2.100001: goo:100
+CHECK-NEXT:   21: bat:100
+CHECK-NEXT:    211: 100
diff --git a/llvm/tools/llvm-profdata/llvm-profdata.cpp b/llvm/tools/llvm-profdata/llvm-profdata.cpp
index f4b68c7c97379c..372b9475b87450 100644
--- a/llvm/tools/llvm-profdata/llvm-profdata.cpp
+++ b/llvm/tools/llvm-profdata/llvm-profdata.cpp
@@ -833,16 +833,15 @@ static void mergeWriterContexts(WriterContext *Dst, WriterContext *Src) {
 // "foo.*bar:1 @ baz" may match "foo:1 @ bar:1 @ baz". The user should specify
 // regex pattern in a way not to match strings that are not valid mangled names.
 static void filterFunctions(SampleProfileMap &Profiles,
-                            std::string FilterString, bool EraseMatch) {
+                            std::string FilterString, bool Inverse) {
   // Checking all call targets is very slow, only do this if FilterString can
   // ever match a call target.
   bool MatchCallTargets = (FilterString.find(" @@ ") != std::string::npos);
 
   // Search inlined callsites recursively is extremely slow, only do this if
   // FilterString has more than one part delimited by " @ " (except for CSSPGO
-  // top level function, we will check for that later) or " @@ ".
-  bool SearchInlinedCallsites =
-      MatchCallTargets || (FilterString.find(" @ ") != std::string::npos);
+  // top level function, we will check for that later).
+  bool SearchInlinedCallsites = (FilterString.find(" @ ") != std::string::npos);
 
   uint64_t MD5 = 0;
 
@@ -885,26 +884,27 @@ static void filterFunctions(SampleProfileMap &Profiles,
       CanonicalName = "[" + CanonicalName + "]";
     if ((Re.match(CanonicalName) ||
          (FunctionSamples::UseMD5 &&
-          FS->second.getContext().getHashCode() == MD5)) == EraseMatch) {
+          FS->second.getContext().getHashCode() == MD5)) != Inverse) {
       FS = Profiles.erase(FS);
       continue;
     }
-    if (SearchInlinedCallsites)
-      FS->second.eraseInlinedCallsites(Re, CanonicalName, MatchCallTargets,
-                                       EraseMatch);
+    // Perform expensive recursive search if the user specifies such pattern.
+    if (MatchCallTargets || SearchInlinedCallsites)
+      FS->second.removeCallTargetsAndCallsites(Re, CanonicalName, Inverse);
     FS++;
   }
 }
 
 static void filterFunctions(StringMap<InstrProfWriter::ProfilingData> &Profiles,
-                            std::string FilterString, bool EraseMatch) {
+                            std::string FilterString, bool Inverse) {
   // If Pattern is quoted string, treat it as escaped regex, otherwise treat it
   // as literal match.
   if (FilterString[0] == '\"') {
     if (FilterString.size() < 2 || FilterString.back() != '\"')
       exitWithError("missing terminating '\"' character");
     FilterString = FilterString.substr(1, FilterString.length() - 2);
-  }
+  } else
+    FilterString = "^" + llvm::Regex::escape(FilterString) + "$";
 
   llvm::Regex Re(FilterString);
   if (std::string Error; !Re.isValid(Error))
@@ -912,7 +912,7 @@ static void filterFunctions(StringMap<InstrProfWriter::ProfilingData> &Profiles,
 
   for (auto ProfileIt = Profiles.begin(); ProfileIt != Profiles.end();) {
     auto Tmp = ProfileIt++;
-    if (Re.match(Tmp->first()) == EraseMatch)
+    if (Re.match(Tmp->first()) != Inverse)
       Profiles.erase(Tmp);
   }
 }
@@ -926,9 +926,9 @@ template <typename T> static void filterFunctions(T &Profiles) {
   size_t Count = Profiles.size();
 
   if (!FuncNameFilter.empty())
-    filterFunctions(Profiles, FuncNameFilter, false);
+    filterFunctions(Profiles, FuncNameFilter, true);
   if (!FuncNameNegativeFilter.empty())
-    filterFunctions(Profiles, FuncNameNegativeFilter, true);
+    filterFunctions(Profiles, FuncNameNegativeFilter, false);
 
   llvm::dbgs() << Count - Profiles.size() << " of " << Count << " functions "
                << "in the original profile are filtered.\n";



More information about the llvm-commits mailing list