[llvm] 6047deb - [llvm] Provide utility function for MD_prof
Paul Kirth via llvm-commits
llvm-commits at lists.llvm.org
Wed Jul 27 14:14:03 PDT 2022
Author: Paul Kirth
Date: 2022-07-27T21:13:51Z
New Revision: 6047deb7c2aa94d9bc2b70b49799d22cce778bd4
URL: https://github.com/llvm/llvm-project/commit/6047deb7c2aa94d9bc2b70b49799d22cce778bd4
DIFF: https://github.com/llvm/llvm-project/commit/6047deb7c2aa94d9bc2b70b49799d22cce778bd4.diff
LOG: [llvm] Provide utility function for MD_prof
Currently, there is significant code duplication for dealing with
MD_prof metadata throughout the compiler. These utility functions can
improve code reuse and simplify boilerplate code when dealing with
profiling metadata, such as branch weights. The intent is to provide a
uniform set of APIs that allow common tasks, such as identifying
specific types of MD_prof metadata and extracting branch weights.
Future patches can build on this initial implementation and clean up the
different implementations across the compiler.
Reviewed By: bogner
Differential Revision: https://reviews.llvm.org/D128858
Added:
llvm/include/llvm/IR/ProfDataUtils.h
llvm/lib/IR/ProfDataUtils.cpp
llvm/unittests/Transforms/Utils/ProfDataUtilTest.cpp
Modified:
llvm/lib/IR/CMakeLists.txt
llvm/unittests/Transforms/Utils/CMakeLists.txt
Removed:
################################################################################
diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
new file mode 100644
index 0000000000000..86ea18d7f1712
--- /dev/null
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -0,0 +1,56 @@
+#ifndef LLVM_IR_PROFDATAUTILS_H
+#define LLVM_IR_PROFDATAUTILS_H
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/IR/Metadata.h"
+
+namespace llvm {
+
+/// Checks if an Instruction has MD_prof Metadata
+bool hasProfMD(const Instruction &I);
+
+/// Checks if an MDNode contains Branch Weight Metadata
+///
+/// \param ProfileData The node to inspect. May be null, in which case false is
+/// returned.
+bool isBranchWeightMD(const MDNode *ProfileData);
+
+/// Checks if an instruction has Branch Weight Metadata
+///
+/// \param I The instruction to check
+/// \return True if I has an MD_prof node containing Branch Weights. False
+/// otherwise.
+bool hasBranchWeightMD(const Instruction &I);
+
+/// Extract branch weights from MD_prof metadata
+///
+/// \param ProfileData A pointer to an MDNode. May be null.
+/// \param [out] Weights An output vector to fill with branch weights
+/// \return True if weights were extracted, False otherwise. When false Weights
+/// will be cleared.
+bool extractBranchWeights(const MDNode *ProfileData,
+                          SmallVectorImpl<uint32_t> &Weights);
+
+/// Extract branch weights attached to an Instruction
+///
+/// \param I The Instruction to extract weights from.
+/// \param [out] Weights An output vector to fill with branch weights
+/// \return True if weights were extracted, False otherwise. When false Weights
+/// will be cleared.
+bool extractBranchWeights(const Instruction &I,
+                          SmallVectorImpl<uint32_t> &Weights);
+
+/// Retrieve the raw weight values of a conditional branch or select.
+///
+/// \param I A branch or select instruction (checked by assertion).
+/// \param [out] TrueVal The weight of the first successor / true operand.
+/// \param [out] FalseVal The weight of the second successor / false operand.
+/// \return True on success with profile weights filled in. False if no
+/// metadata, invalid metadata, or a weight count other than two was found.
+bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal,
+                          uint64_t &FalseVal);
+
+/// Retrieve the total of all weights from MD_prof data.
+///
+/// \param ProfileData The profile data to extract the total weight from
+/// \param [out] TotalWeights Output variable to fill with the total weight; it
+/// is set to 0 whenever false is returned.
+/// \return true on success with profile total weights filled in.
+/// \return false if ProfileData is null or is not a "branch_weights" or "VP"
+/// node that the implementation recognizes.
+bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalWeights);
+
+} // namespace llvm
+#endif // LLVM_IR_PROFDATAUTILS_H
diff --git a/llvm/lib/IR/CMakeLists.txt b/llvm/lib/IR/CMakeLists.txt
index 3e542e4622fbd..13a28f51bcd61 100644
--- a/llvm/lib/IR/CMakeLists.txt
+++ b/llvm/lib/IR/CMakeLists.txt
@@ -49,6 +49,7 @@ add_llvm_component_library(LLVMCore
PassRegistry.cpp
PassTimingInfo.cpp
PrintPasses.cpp
+ ProfDataUtils.cpp
SafepointIRVerifier.cpp
ProfileSummary.cpp
PseudoProbe.cpp
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
new file mode 100644
index 0000000000000..d0550b774d4c8
--- /dev/null
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -0,0 +1,150 @@
+#include "llvm/IR/ProfDataUtils.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/Support/BranchProbability.h"
+#include "llvm/Support/CommandLine.h"
+
+using namespace llvm;
+
+namespace {
+
+// MD_prof nodes have the following layout
+//
+// In general:
+// { String name, weight operand, weight operand, ... }
+//
+// In terms of Types:
+// { MDString, i32, i32, ... }
+//
+// Concretely for Branch Weights
+// !{!"branch_weights", i32 1, i32 10000}
+//
+// The weights are individual operands that follow the name; they are not a
+// nested array. We maintain some constants here to ensure that we access the
+// branch weights correctly, and can change the behavior in the future if the
+// layout changes.
+
+// The operand index of the first weight.
+constexpr unsigned WeightsIdx = 1;
+
+// The minimum number of operands for MD_prof nodes with branch weights: the
+// name plus at least two weights.
+constexpr unsigned MinBWOps = 3;
+
+// Copy every weight operand of ProfileData into Weights, replacing any
+// previous contents. The caller must already have validated the node (e.g.
+// via isTargetMD); malformed operands are only diagnosed by assertions.
+bool extractWeights(const MDNode *ProfileData,
+                    SmallVectorImpl<uint32_t> &Weights) {
+  // Assume preconditions are already met (i.e. this is valid metadata)
+  assert(ProfileData && "ProfileData was nullptr in extractWeights");
+  const unsigned NOps = ProfileData->getNumOperands();
+
+  assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
+  Weights.resize(NOps - WeightsIdx);
+
+  for (unsigned Idx = WeightsIdx; Idx != NOps; ++Idx) {
+    ConstantInt *Weight =
+        mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
+    assert(Weight && "Malformed branch_weight in MD_prof node");
+    assert(Weight->getValue().getActiveBits() <= 32 &&
+           "Too many bits for uint32_t");
+    Weights[Idx - WeightsIdx] = Weight->getZExtValue();
+  }
+  return true;
+}
+
+// We may want to add support for other MD_prof types, so provide an
+// abstraction for checking the metadata type. Returns true only when ProfData
+// is non-null, has at least MinOps operands, and its first operand is an
+// MDString equal to Name.
+bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) {
+  // TODO: This routine may be simplified if MD_prof used an enum instead of a
+  // string to differentiate the types of MD_prof nodes.
+  if (!ProfData || !Name || MinOps < 2)
+    return false;
+
+  const unsigned NOps = ProfData->getNumOperands();
+  if (NOps < MinOps)
+    return false;
+
+  auto *ProfDataName = dyn_cast<MDString>(ProfData->getOperand(0));
+  if (!ProfDataName)
+    return false;
+
+  return ProfDataName->getString() == Name;
+}
+
+} // namespace
+
+namespace llvm {
+
+bool hasProfMD(const Instruction &I) {
+  return I.getMetadata(LLVMContext::MD_prof) != nullptr;
+}
+
+bool isBranchWeightMD(const MDNode *ProfileData) {
+  return isTargetMD(ProfileData, "branch_weights", MinBWOps);
+}
+
+bool hasBranchWeightMD(const Instruction &I) {
+  auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
+  return isBranchWeightMD(ProfileData);
+}
+
+bool extractBranchWeights(const MDNode *ProfileData,
+                          SmallVectorImpl<uint32_t> &Weights) {
+  if (!isBranchWeightMD(ProfileData)) {
+    // Documented contract: Weights is cleared when nothing is extracted, so
+    // callers never observe stale values from a previous call.
+    Weights.clear();
+    return false;
+  }
+  return extractWeights(ProfileData, Weights);
+}
+
+bool extractBranchWeights(const Instruction &I,
+                          SmallVectorImpl<uint32_t> &Weights) {
+  auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
+  return extractBranchWeights(ProfileData, Weights);
+}
+
+bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal,
+                          uint64_t &FalseVal) {
+  assert((I.getOpcode() == Instruction::Br ||
+          I.getOpcode() == Instruction::Select) &&
+         "Looking for branch weights on something besides branch or select");
+
+  SmallVector<uint32_t, 2> Weights;
+  auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
+  if (!extractBranchWeights(ProfileData, Weights))
+    return false;
+
+  // A two-way branch/select needs exactly one weight per side. Check the size
+  // explicitly instead of relying on MinBWOps, so the indexing below can never
+  // go out of bounds.
+  if (Weights.size() != 2)
+    return false;
+
+  TrueVal = Weights[0];
+  FalseVal = Weights[1];
+  return true;
+}
+
+bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalWeights) {
+  TotalWeights = 0;
+  // getOperand(0) below requires at least one operand.
+  if (!ProfileData || ProfileData->getNumOperands() == 0)
+    return false;
+
+  auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
+  if (!ProfDataName)
+    return false;
+
+  if (ProfDataName->getString() == "branch_weights") {
+    for (unsigned Idx = WeightsIdx, E = ProfileData->getNumOperands(); Idx != E;
+         ++Idx) {
+      auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
+      assert(V && "Malformed branch_weight in MD_prof node");
+      TotalWeights += V->getZExtValue();
+    }
+    return true;
+  }
+
+  // For "VP" nodes, operand 2 is read as the total weight.
+  if (ProfDataName->getString() == "VP" && ProfileData->getNumOperands() > 3) {
+    auto *Total = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2));
+    // Malformed node: report failure instead of dereferencing null.
+    if (!Total)
+      return false;
+    TotalWeights = Total->getZExtValue();
+    return true;
+  }
+  return false;
+}
+
+} // namespace llvm
diff --git a/llvm/unittests/Transforms/Utils/CMakeLists.txt b/llvm/unittests/Transforms/Utils/CMakeLists.txt
index b714c8eabed92..751de6fc5becc 100644
--- a/llvm/unittests/Transforms/Utils/CMakeLists.txt
+++ b/llvm/unittests/Transforms/Utils/CMakeLists.txt
@@ -30,6 +30,7 @@ add_llvm_unittest(UtilsTests
UnrollLoopTest.cpp
ValueMapperTest.cpp
VFABIUtils.cpp
+ ProfDataUtilTest.cpp
)
set_property(TARGET UtilsTests PROPERTY FOLDER "Tests/UnitTests/TransformsTests")
diff --git a/llvm/unittests/Transforms/Utils/ProfDataUtilTest.cpp b/llvm/unittests/Transforms/Utils/ProfDataUtilTest.cpp
new file mode 100644
index 0000000000000..f32108a7e058d
--- /dev/null
+++ b/llvm/unittests/Transforms/Utils/ProfDataUtilTest.cpp
@@ -0,0 +1,93 @@
+//===----- ProfDataUtilTest.cpp - Unit tests for ProfDataUtils ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/BasicAliasAnalysis.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/Analysis/BranchProbabilityInfo.h"
+#include "llvm/Analysis/CFG.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/MemorySSA.h"
+#include "llvm/Analysis/MemorySSAUpdater.h"
+#include "llvm/Analysis/PostDominators.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/ProfDataUtils.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Transforms/Utils/BreakCriticalEdges.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+// Parses the textual IR into a module owned by context C. On a parse failure
+// the diagnostic is printed to errs() and a null pointer is returned.
+static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
+  SMDiagnostic Diag;
+  auto ParsedModule = parseAssemblyString(IR, Diag, C);
+  if (ParsedModule == nullptr)
+    Diag.print("ProfDataUtilsTests", errs());
+  return ParsedModule;
+}
+
+TEST(ProfDataUtils, extractWeights) {
+  LLVMContext C;
+  std::unique_ptr<Module> M = parseIR(C, R"IR(
+define void @foo(i1 %cond0) {
+entry:
+  br i1 %cond0, label %bb0, label %bb1, !prof !1
+bb0:
+  %0 = mul i32 1, 2
+  br label %bb1
+bb1:
+  ret void
+}
+
+!1 = !{!"branch_weights", i32 1, i32 100000}
+)IR");
+  // M, F and Branch are dereferenced below, so abort the test (rather than
+  // crash the test binary) if any of them is missing.
+  ASSERT_TRUE(M);
+  Function *F = M->getFunction("foo");
+  ASSERT_NE(nullptr, F);
+  auto &Entry = F->getEntryBlock();
+  auto &I = Entry.front();
+  auto *Branch = dyn_cast<BranchInst>(&I);
+  ASSERT_NE(nullptr, Branch);
+  auto *ProfileData = Branch->getMetadata(LLVMContext::MD_prof);
+  EXPECT_NE(ProfileData, nullptr);
+  EXPECT_TRUE(hasProfMD(I));
+  EXPECT_TRUE(hasBranchWeightMD(I));
+
+  SmallVector<uint32_t> Weights;
+  EXPECT_TRUE(extractBranchWeights(ProfileData, Weights));
+  // Check the size before indexing so a failed extraction cannot read out of
+  // bounds.
+  ASSERT_EQ(Weights.size(), 2U);
+  EXPECT_EQ(Weights[0], 1U);
+  EXPECT_EQ(Weights[1], 100000U);
+
+  uint64_t TrueWeight = 0, FalseWeight = 0;
+  EXPECT_TRUE(extractBranchWeights(I, TrueWeight, FalseWeight));
+  EXPECT_EQ(TrueWeight, 1U);
+  EXPECT_EQ(FalseWeight, 100000U);
+
+  uint64_t TotalWeight = 0;
+  EXPECT_TRUE(extractProfTotalWeight(ProfileData, TotalWeight));
+  EXPECT_EQ(TotalWeight, 100001U);
+}
+
+TEST(ProfDataUtils, NoWeights) {
+  LLVMContext C;
+  std::unique_ptr<Module> M = parseIR(C, R"IR(
+define void @foo(i1 %cond0) {
+entry:
+  br i1 %cond0, label %bb0, label %bb1
+bb0:
+  %0 = mul i32 1, 2
+  br label %bb1
+bb1:
+  ret void
+}
+)IR");
+  // M, F and Branch are dereferenced below, so abort the test (rather than
+  // crash the test binary) if any of them is missing.
+  ASSERT_TRUE(M);
+  Function *F = M->getFunction("foo");
+  ASSERT_NE(nullptr, F);
+  auto &Entry = F->getEntryBlock();
+  auto &I = Entry.front();
+  auto *Branch = dyn_cast<BranchInst>(&I);
+  ASSERT_NE(nullptr, Branch);
+  auto *ProfileData = Branch->getMetadata(LLVMContext::MD_prof);
+  EXPECT_EQ(ProfileData, nullptr);
+  EXPECT_FALSE(hasProfMD(I));
+  EXPECT_FALSE(hasBranchWeightMD(I));
+
+  SmallVector<uint32_t> Weights;
+  EXPECT_FALSE(extractBranchWeights(ProfileData, Weights));
+  EXPECT_EQ(Weights.size(), 0U);
+
+  uint64_t TrueWeight = 0, FalseWeight = 0;
+  EXPECT_FALSE(extractBranchWeights(I, TrueWeight, FalseWeight));
+
+  // extractProfTotalWeight resets its output to 0 before reporting failure.
+  uint64_t TotalWeight = 42;
+  EXPECT_FALSE(extractProfTotalWeight(ProfileData, TotalWeight));
+  EXPECT_EQ(TotalWeight, 0U);
+}
More information about the llvm-commits
mailing list