[llvm] 6047deb - [llvm] Provide utility function for MD_prof

Paul Kirth via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 27 14:14:03 PDT 2022


Author: Paul Kirth
Date: 2022-07-27T21:13:51Z
New Revision: 6047deb7c2aa94d9bc2b70b49799d22cce778bd4

URL: https://github.com/llvm/llvm-project/commit/6047deb7c2aa94d9bc2b70b49799d22cce778bd4
DIFF: https://github.com/llvm/llvm-project/commit/6047deb7c2aa94d9bc2b70b49799d22cce778bd4.diff

LOG: [llvm] Provide utility function for MD_prof

Currently, there is significant code duplication for dealing with
MD_prof metadata throughout the compiler. These utility functions can
improve code reuse and simplify boilerplate code when dealing with
profiling metadata, such as branch weights. The intent is to provide a
uniform set of APIs that allow common tasks, such as identifying
specific types of MD_prof metadata and extracting branch weights.

Future patches can build on this initial implementation and clean up the
different implementations across the compiler.

Reviewed By: bogner

Differential Revision: https://reviews.llvm.org/D128858

Added: 
    llvm/include/llvm/IR/ProfDataUtils.h
    llvm/lib/IR/ProfDataUtils.cpp
    llvm/unittests/Transforms/Utils/ProfDataUtilTest.cpp

Modified: 
    llvm/lib/IR/CMakeLists.txt
    llvm/unittests/Transforms/Utils/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
new file mode 100644
index 0000000000000..86ea18d7f1712
--- /dev/null
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -0,0 +1,56 @@
+#ifndef PROF_DATA_UTILS_H
+#define PROF_DATA_UTILS_H
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/IR/Metadata.h"
+
+namespace llvm {
+
+/// Checks if an Instruction has MD_prof Metadata
+bool hasProfMD(const Instruction &I);
+
+/// Checks if an MDNode contains Branch Weight Metadata
+bool isBranchWeightMD(const MDNode *ProfileData);
+
+/// Checks if an instruction has Branch Weight Metadata
+///
+/// \param I The instruction to check
+/// \return True if I has an MD_prof node containing Branch Weights. False
+/// otherwise.
+bool hasBranchWeightMD(const Instruction &I);
+
+/// Extract branch weights from MD_prof metadata
+///
+/// \param ProfileData A pointer to an MDNode.
+/// \param Weights An output vector to fill with branch weights
+/// \return True if weights were extracted, False otherwise. When false, Weights
+/// is left unmodified.
+bool extractBranchWeights(const MDNode *ProfileData,
+                          SmallVectorImpl<uint32_t> &Weights);
+
+/// Extract branch weights attached to an Instruction
+///
+/// \param I The Instruction to extract weights from.
+/// \param Weights An output vector to fill with branch weights
+/// \return True if weights were extracted, False otherwise. When false, Weights
+/// is left unmodified.
+bool extractBranchWeights(const Instruction &I,
+                          SmallVectorImpl<uint32_t> &Weights);
+
+/// Retrieve the raw weight values of a conditional branch or select.
+/// Returns true on success with profile weights filled in.
+/// Returns false if no metadata or invalid metadata was found.
+bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal,
+                          uint64_t &FalseVal);
+
+/// Retrieve the total of all weights from MD_prof data.
+///
+/// \param ProfileData The profile data to extract the total weight from
+/// \param TotalWeights output variable to fill with total weights
+/// \return true on success with profile total weights filled in.
+/// \return false if no metadata was found.
+bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalWeights);
+
+} // namespace llvm
+#endif

diff  --git a/llvm/lib/IR/CMakeLists.txt b/llvm/lib/IR/CMakeLists.txt
index 3e542e4622fbd..13a28f51bcd61 100644
--- a/llvm/lib/IR/CMakeLists.txt
+++ b/llvm/lib/IR/CMakeLists.txt
@@ -49,6 +49,7 @@ add_llvm_component_library(LLVMCore
   PassRegistry.cpp
   PassTimingInfo.cpp
   PrintPasses.cpp
+  ProfDataUtils.cpp
   SafepointIRVerifier.cpp
   ProfileSummary.cpp
   PseudoProbe.cpp

diff  --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
new file mode 100644
index 0000000000000..d0550b774d4c8
--- /dev/null
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -0,0 +1,150 @@
+#include "llvm/IR/ProfDataUtils.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/Support/BranchProbability.h"
+#include "llvm/Support/CommandLine.h"
+
+using namespace llvm;
+
+namespace {
+
+// MD_prof nodes have the following layout
+//
+// In general:
+// { String name,         Array of i32   }
+//
+// In terms of Types:
+// { MDString,            [i32, i32, ...]}
+//
+// Concretely for Branch Weights
+// { "branch_weights",    [i32 1, i32 10000]}
+//
+// We maintain some constants here to ensure that we access the branch weights
+// correctly, and can change the behavior in the future if the layout changes
+
+// The index at which the weights vector starts
+constexpr unsigned WeightsIdx = 1;
+
+// the minimum number of operands for MD_prof nodes with branch weights
+constexpr unsigned MinBWOps = 3;
+
+bool extractWeights(const MDNode *ProfileData,
+                    SmallVectorImpl<uint32_t> &Weights) {
+  // Assume preconditions are already met (i.e. this is valid metadata)
+  assert(ProfileData && "ProfileData was nullptr in extractWeights");
+  unsigned NOps = ProfileData->getNumOperands();
+
+  assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
+  Weights.resize(NOps - WeightsIdx);
+
+  for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
+    ConstantInt *Weight =
+        mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
+    assert(Weight && "Malformed branch_weight in MD_prof node");
+    assert(Weight->getValue().getActiveBits() <= 32 &&
+           "Too many bits for uint32_t");
+    Weights[Idx - WeightsIdx] = Weight->getZExtValue();
+  }
+  return true;
+}
+
+// We may want to add support for other MD_prof types, so provide an abstraction
+// for checking the metadata type.
+bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) {
+  // TODO: This routine may be simplified if MD_prof used an enum instead of a
+  // string to differentiate the types of MD_prof nodes.
+  if (!ProfData || !Name || MinOps < 2)
+    return false;
+
+  unsigned NOps = ProfData->getNumOperands();
+  if (NOps < MinOps)
+    return false;
+
+  auto *ProfDataName = dyn_cast<MDString>(ProfData->getOperand(0));
+  if (!ProfDataName)
+    return false;
+
+  return ProfDataName->getString().equals(Name);
+}
+
+} // namespace
+
+namespace llvm {
+
+bool hasProfMD(const Instruction &I) {
+  return nullptr != I.getMetadata(LLVMContext::MD_prof);
+}
+
+bool isBranchWeightMD(const MDNode *ProfileData) {
+  return isTargetMD(ProfileData, "branch_weights", MinBWOps);
+}
+
+bool hasBranchWeightMD(const Instruction &I) {
+  auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
+  return isBranchWeightMD(ProfileData);
+}
+
+bool extractBranchWeights(const MDNode *ProfileData,
+                          SmallVectorImpl<uint32_t> &Weights) {
+  if (!isBranchWeightMD(ProfileData))
+    return false;
+  return extractWeights(ProfileData, Weights);
+}
+
+bool extractBranchWeights(const Instruction &I,
+                          SmallVectorImpl<uint32_t> &Weights) {
+  auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
+  return extractBranchWeights(ProfileData, Weights);
+}
+
+bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal,
+                          uint64_t &FalseVal) {
+  assert((I.getOpcode() == Instruction::Br ||
+          I.getOpcode() == Instruction::Select) &&
+         "Looking for branch weights on something besides branch or select");
+
+  SmallVector<uint32_t, 2> Weights;
+  auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
+  if (!extractBranchWeights(ProfileData, Weights))
+    return false;
+
+  if (Weights.size() > 2)
+    return false;
+
+  TrueVal = Weights[0];
+  FalseVal = Weights[1];
+  return true;
+}
+
+bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) {
+  TotalVal = 0;
+  if (!ProfileData)
+    return false;
+
+  auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
+  if (!ProfDataName)
+    return false;
+
+  if (ProfDataName->getString().equals("branch_weights")) {
+    for (unsigned Idx = 1; Idx < ProfileData->getNumOperands(); Idx++) {
+      auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
+      assert(V && "Malformed branch_weight in MD_prof node");
+      TotalVal += V->getValue().getZExtValue();
+    }
+    return true;
+  } else if (ProfDataName->getString().equals("VP") &&
+             ProfileData->getNumOperands() > 3) {
+    TotalVal = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2))
+                   ->getValue()
+                   .getZExtValue();
+    return true;
+  }
+  return false;
+}
+
+} // namespace llvm

diff  --git a/llvm/unittests/Transforms/Utils/CMakeLists.txt b/llvm/unittests/Transforms/Utils/CMakeLists.txt
index b714c8eabed92..751de6fc5becc 100644
--- a/llvm/unittests/Transforms/Utils/CMakeLists.txt
+++ b/llvm/unittests/Transforms/Utils/CMakeLists.txt
@@ -30,6 +30,7 @@ add_llvm_unittest(UtilsTests
   UnrollLoopTest.cpp
   ValueMapperTest.cpp
   VFABIUtils.cpp
+  ProfDataUtilTest.cpp
   )
 
 set_property(TARGET UtilsTests PROPERTY FOLDER "Tests/UnitTests/TransformsTests")

diff  --git a/llvm/unittests/Transforms/Utils/ProfDataUtilTest.cpp b/llvm/unittests/Transforms/Utils/ProfDataUtilTest.cpp
new file mode 100644
index 0000000000000..f32108a7e058d
--- /dev/null
+++ b/llvm/unittests/Transforms/Utils/ProfDataUtilTest.cpp
@@ -0,0 +1,93 @@
+//===----- ProfDataUtilTest.cpp - Unit tests for ProfDataUtils ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/BasicAliasAnalysis.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/Analysis/BranchProbabilityInfo.h"
+#include "llvm/Analysis/CFG.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/MemorySSA.h"
+#include "llvm/Analysis/MemorySSAUpdater.h"
+#include "llvm/Analysis/PostDominators.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/ProfDataUtils.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Transforms/Utils/BreakCriticalEdges.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
+  SMDiagnostic Err;
+  std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
+  if (!Mod)
+    Err.print("ProfDataUtilsTests", errs());
+  return Mod;
+}
+
+TEST(ProfDataUtils, extractWeights) {
+  LLVMContext C;
+  std::unique_ptr<Module> M = parseIR(C, R"IR(
+define void @foo(i1 %cond0) {
+entry:
+  br i1 %cond0, label %bb0, label %bb1, !prof !1
+bb0:
+ %0 = mul i32 1, 2
+ br label %bb1
+bb1:
+  ret void
+}
+
+!1 = !{!"branch_weights", i32 1, i32 100000}
+)IR");
+  Function *F = M->getFunction("foo");
+  auto &Entry = F->getEntryBlock();
+  auto &I = Entry.front();
+  auto *Branch = dyn_cast<BranchInst>(&I);
+  EXPECT_NE(nullptr, Branch);
+  auto *ProfileData = Branch->getMetadata(LLVMContext::MD_prof);
+  EXPECT_NE(ProfileData, nullptr);
+  EXPECT_TRUE(hasProfMD(I));
+  SmallVector<uint32_t> Weights;
+  EXPECT_TRUE(extractBranchWeights(ProfileData, Weights));
+  EXPECT_EQ(Weights[0], 1U);
+  EXPECT_EQ(Weights[1], 100000U);
+  EXPECT_EQ(Weights.size(), 2U);
+}
+
+TEST(ProfDataUtils, NoWeights) {
+  LLVMContext C;
+  std::unique_ptr<Module> M = parseIR(C, R"IR(
+define void @foo(i1 %cond0) {
+entry:
+  br i1 %cond0, label %bb0, label %bb1
+bb0:
+ %0 = mul i32 1, 2
+ br label %bb1
+bb1:
+  ret void
+}
+)IR");
+  Function *F = M->getFunction("foo");
+  auto &Entry = F->getEntryBlock();
+  auto &I = Entry.front();
+  auto *Branch = dyn_cast<BranchInst>(&I);
+  EXPECT_NE(nullptr, Branch);
+  auto *ProfileData = Branch->getMetadata(LLVMContext::MD_prof);
+  EXPECT_EQ(ProfileData, nullptr);
+  EXPECT_FALSE(hasProfMD(I));
+  SmallVector<uint32_t> Weights;
+  EXPECT_FALSE(extractBranchWeights(ProfileData, Weights));
+  EXPECT_EQ(Weights.size(), 0U);
+}


        


More information about the llvm-commits mailing list