[llvm] r361596 - [NFC] SwitchInst: Introduce wrapper for prof branch_weights handling

Yevgeny Rouban via llvm-commits llvm-commits at lists.llvm.org
Thu May 23 21:34:24 PDT 2019


Author: yrouban
Date: Thu May 23 21:34:23 2019
New Revision: 361596

URL: http://llvm.org/viewvc/llvm-project?rev=361596&view=rev
Log:
[NFC] SwitchInst: Introduce wrapper for prof branch_weights handling

This patch introduces a wrapper class that re-implements
several mutator methods of SwitchInst so that the prof
branch_weights metadata is kept up to date whenever switch
cases are added or removed.
Subsequent patches will use this wrapper to implement
prof branch_weights metadata handling for SwitchInst.

Reviewers: davidx, eraman, reames, chandlerc
Reviewed By: davidx
Differential Revision: https://reviews.llvm.org/D62122

Modified:
    llvm/trunk/include/llvm/IR/Instructions.h
    llvm/trunk/lib/IR/Instructions.cpp

Modified: llvm/trunk/include/llvm/IR/Instructions.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/IR/Instructions.h?rev=361596&r1=361595&r2=361596&view=diff
==============================================================================
--- llvm/trunk/include/llvm/IR/Instructions.h (original)
+++ llvm/trunk/include/llvm/IR/Instructions.h Thu May 23 21:34:23 2019
@@ -3435,6 +3435,52 @@ public:
   }
 };
 
+/// Helper that edits the cases of a SwitchInst while keeping a cached copy of
+/// its prof branch_weights metadata in step with those edits.
+///
+/// The weights are read from the instruction once, on construction. Mutators
+/// update the cached copy and record whether anything changed; the destructor
+/// writes the metadata back to the instruction only if a change was recorded.
+class SwitchInstProfUpdateWrapper {
+  /// The wrapped instruction.
+  SwitchInst &SI;
+  /// Cached branch weights; None while no weights are known for SI.
+  Optional<SmallVector<uint32_t, 8>> Weights;
+  /// Set by the mutators when the cached weights have been modified.
+  bool Changed = false;
+
+protected:
+  /// Return the prof metadata node of \p SI if it is a branch_weights node,
+  /// nullptr otherwise.
+  static MDNode *getProfBranchWeightsMD(const SwitchInst &SI);
+
+  /// Build the metadata node to attach to SI from the cached weights.
+  MDNode *buildProfBranchWeightsMD();
+
+  /// Read the branch weights currently attached to SI, if any.
+  Optional<SmallVector<uint32_t, 8>> getProfBranchWeights();
+
+public:
+  using CaseWeightOpt = Optional<uint32_t>;
+
+  SwitchInstProfUpdateWrapper(SwitchInst &SI)
+      : SI(SI), Weights(getProfBranchWeights()) {}
+
+  /// Write the cached weights back to the instruction, but only when one of
+  /// the mutators recorded a change.
+  ~SwitchInstProfUpdateWrapper() {
+    if (!Changed)
+      return;
+    SI.setMetadata(LLVMContext::MD_prof, buildProfBranchWeightsMD());
+  }
+
+  /// Access to the wrapped SwitchInst.
+  SwitchInst &operator*() { return SI; }
+  SwitchInst *operator->() { return &SI; }
+  operator SwitchInst *() { return &SI; }
+
+  /// Delegate the call to the underlying SwitchInst::addCase() and set the
+  /// specified branch weight for the added case.
+  void addCase(ConstantInt *OnVal, BasicBlock *Dest, CaseWeightOpt W);
+
+  /// Delegate the call to the underlying SwitchInst::removeCase() and remove
+  /// the corresponding branch weight.
+  SwitchInst::CaseIt removeCase(SwitchInst::CaseIt I);
+
+  /// Delegate the call to the underlying SwitchInst::eraseFromParent() and mark
+  /// this object to not touch the underlying SwitchInst in destructor.
+  SymbolTableList<Instruction>::iterator eraseFromParent();
+
+  /// Return the cached weight of successor \p idx, or None when no weights
+  /// are cached.
+  CaseWeightOpt getSuccessorWeight(unsigned idx);
+
+  /// Set the cached weight of successor \p idx to \p W. A None \p W leaves
+  /// everything untouched.
+  void setSuccessorWeight(unsigned idx, CaseWeightOpt W);
+
+  /// Read the weight of successor \p idx straight from the branch_weights
+  /// metadata of \p SI; None when \p SI carries no such metadata.
+  static CaseWeightOpt getSuccessorWeight(const SwitchInst &SI, unsigned idx);
+};
+
 template <>
 struct OperandTraits<SwitchInst> : public HungoffOperandTraits<2> {
 };

Modified: llvm/trunk/lib/IR/Instructions.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/IR/Instructions.cpp?rev=361596&r1=361595&r2=361596&view=diff
==============================================================================
--- llvm/trunk/lib/IR/Instructions.cpp (original)
+++ llvm/trunk/lib/IR/Instructions.cpp Thu May 23 21:34:23 2019
@@ -3870,6 +3870,126 @@ void SwitchInst::growOperands() {
   growHungoffUses(ReservedSpace);
 }
 
+/// Return the prof metadata node attached to \p SI if its first operand is
+/// the string "branch_weights"; return nullptr in every other case (no prof
+/// metadata, a node without operands, or a node of a different kind).
+MDNode *
+SwitchInstProfUpdateWrapper::getProfBranchWeightsMD(const SwitchInst &SI) {
+  MDNode *ProfileData = SI.getMetadata(LLVMContext::MD_prof);
+  // A node without operands has no name to inspect; asking for operand 0
+  // would be out of range.
+  if (!ProfileData || ProfileData->getNumOperands() == 0)
+    return nullptr;
+
+  // The operand itself may be null, so use the null-tolerant cast.
+  auto *MDName = dyn_cast_or_null<MDString>(ProfileData->getOperand(0));
+  if (MDName && MDName->getString() == "branch_weights")
+    return ProfileData;
+
+  return nullptr;
+}
+
+/// Build the branch_weights metadata node for the cached weights.
+///
+/// Must only be called after a change has been recorded. Returns nullptr
+/// when there are no cached weights, when there are fewer than two of them,
+/// or when all of them are zero.
+MDNode *SwitchInstProfUpdateWrapper::buildProfBranchWeightsMD() {
+  assert(Changed && "called only if metadata has changed");
+
+  if (!Weights)
+    return nullptr;
+
+  assert(SI.getNumSuccessors() == Weights->size() &&
+         "num of prof branch_weights must accord with num of successors");
+
+  const SmallVector<uint32_t, 8> &W = *Weights;
+  if (W.size() < 2 || all_of(W, [](uint32_t X) { return X == 0; }))
+    return nullptr;
+
+  // Take the context from the instruction itself rather than from its parent
+  // block, so that a switch which is not inserted in a block does not lead to
+  // a null parent being dereferenced.
+  return MDBuilder(SI.getContext()).createBranchWeights(W);
+}
+
+/// Read the branch weights of SI from its branch_weights metadata.
+///
+/// \returns one weight per successor, or None when SI has no branch_weights
+/// metadata or the metadata does not hold exactly one weight per successor.
+Optional<SmallVector<uint32_t, 8> >
+SwitchInstProfUpdateWrapper::getProfBranchWeights() {
+  MDNode *ProfileData = getProfBranchWeightsMD(SI);
+  if (!ProfileData)
+    return None;
+
+  // Operand 0 is the "branch_weights" name, operands 1..N are the weights.
+  // Reject a node whose operand count does not match instead of indexing past
+  // its operands.
+  const unsigned NumSuccessors = SI.getNumSuccessors();
+  if (ProfileData->getNumOperands() != NumSuccessors + 1)
+    return None;
+
+  // Named Result so the local does not shadow the Weights member.
+  SmallVector<uint32_t, 8> Result;
+  Result.reserve(NumSuccessors);
+  for (unsigned OpIdx = 1; OpIdx <= NumSuccessors; ++OpIdx) {
+    ConstantInt *C =
+        mdconst::extract<ConstantInt>(ProfileData->getOperand(OpIdx));
+    Result.push_back(static_cast<uint32_t>(C->getValue().getZExtValue()));
+  }
+  return Result;
+}
+
+/// Remove case \p I from the switch. When weights are cached, the entry that
+/// belongs to \p I is dropped from the cache and the change is recorded.
+/// \returns whatever SwitchInst::removeCase() returns.
+SwitchInst::CaseIt
+SwitchInstProfUpdateWrapper::removeCase(SwitchInst::CaseIt I) {
+  if (!Weights)
+    return SI.removeCase(I);
+
+  assert(SI.getNumSuccessors() == Weights->size() &&
+         "num of prof branch_weights must accord with num of successors");
+  Changed = true;
+
+  // Mirror what SwitchInst::removeCase(CaseIt) does to the cases: move the
+  // last entry into the slot being removed, then shrink by one. The two
+  // implementations have to stay in sync.
+  SmallVector<uint32_t, 8> &W = *Weights;
+  const unsigned Slot = I->getCaseIndex() + 1;
+  W[Slot] = W.back();
+  W.pop_back();
+
+  return SI.removeCase(I);
+}
+
+/// Add a case to the switch and give it the weight \p W.
+///
+/// With cached weights present, the new weight (0 when \p W is None) is
+/// appended. Without cached weights, a cache is created only for a non-zero
+/// \p W: every existing successor gets weight 0 and the new case gets \p W.
+void SwitchInstProfUpdateWrapper::addCase(
+    ConstantInt *OnVal, BasicBlock *Dest,
+    SwitchInstProfUpdateWrapper::CaseWeightOpt W) {
+  SI.addCase(OnVal, Dest);
+
+  if (Weights) {
+    Changed = true;
+    Weights->push_back(W ? *W : 0);
+  } else if (W && *W) {
+    Changed = true;
+    // The case was already added above, so the last slot is the new case.
+    Weights = SmallVector<uint32_t, 8>(SI.getNumSuccessors(), 0);
+    Weights->back() = *W;
+  }
+
+  assert((!Weights || SI.getNumSuccessors() == Weights->size()) &&
+         "num of prof branch_weights must accord with num of successors");
+}
+
+/// Erase the wrapped switch from its parent and make sure the destructor will
+/// not try to write metadata to it afterwards.
+/// \returns whatever SwitchInst::eraseFromParent() returns.
+SymbolTableList<Instruction>::iterator
+SwitchInstProfUpdateWrapper::eraseFromParent() {
+  // Drop any cached weights; they no longer describe anything.
+  if (Weights)
+    Weights->clear();
+
+  // The instruction is going away: clear the flag so the destructor leaves it
+  // alone.
+  Changed = false;
+
+  return SI.eraseFromParent();
+}
+
+/// Return the cached weight of successor \p idx, or None when no weights are
+/// cached.
+SwitchInstProfUpdateWrapper::CaseWeightOpt
+SwitchInstProfUpdateWrapper::getSuccessorWeight(unsigned idx) {
+  if (Weights)
+    return (*Weights)[idx];
+  return None;
+}
+
+/// Set the cached weight of successor \p idx to \p W.
+///
+/// A None \p W is ignored. When no weights are cached yet, a zero weight is
+/// ignored as well, while a non-zero one first creates a cache with weight 0
+/// for every successor. A change is recorded only if the stored value
+/// actually differs from \p W.
+void SwitchInstProfUpdateWrapper::setSuccessorWeight(
+    unsigned idx, SwitchInstProfUpdateWrapper::CaseWeightOpt W) {
+  if (!W)
+    return;
+  const uint32_t NewWeight = *W;
+
+  if (!Weights) {
+    if (NewWeight == 0)
+      return;
+    Weights = SmallVector<uint32_t, 8>(SI.getNumSuccessors(), 0);
+  }
+
+  uint32_t &Slot = (*Weights)[idx];
+  if (Slot == NewWeight)
+    return;
+
+  Slot = NewWeight;
+  Changed = true;
+}
+
+/// Read the weight of successor \p idx directly from the branch_weights
+/// metadata attached to \p SI.
+///
+/// \returns None when \p SI has no branch_weights metadata or when that
+/// metadata does not hold exactly one weight per successor.
+SwitchInstProfUpdateWrapper::CaseWeightOpt
+SwitchInstProfUpdateWrapper::getSuccessorWeight(const SwitchInst &SI,
+                                                unsigned idx) {
+  MDNode *ProfileData = getProfBranchWeightsMD(SI);
+  if (!ProfileData)
+    return None;
+
+  // Operand 0 is the "branch_weights" name, so successor idx lives at operand
+  // idx + 1. Only index into the node when its operand count matches the
+  // successor count; otherwise the read could run past the operands.
+  if (ProfileData->getNumOperands() != SI.getNumSuccessors() + 1)
+    return None;
+  assert(idx < SI.getNumSuccessors() && "successor index out of range");
+
+  const uint32_t Weight = static_cast<uint32_t>(
+      mdconst::extract<ConstantInt>(ProfileData->getOperand(idx + 1))
+          ->getValue()
+          .getZExtValue());
+  return Weight;
+}
+
 //===----------------------------------------------------------------------===//
 //                        IndirectBrInst Implementation
 //===----------------------------------------------------------------------===//




More information about the llvm-commits mailing list