[llvm] [GVNHoist] - Split Parent when profitable to do so. (PR #106842)
via llvm-commits
llvm-commits at lists.llvm.org
Sat Aug 31 03:13:35 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Pawan Nirpal (pawan-nirpal-031)
<details>
<summary>Changes</summary>
GVNHoist currently bails out unless a value is fully anticipable at a dominator.
It may be profitable to split the parents and create a common dominator for
all the basic blocks that have a fully anticipable value.
Ref [https://github.com/llvm/llvm-project/issues/91665]
---
Full diff: https://github.com/llvm/llvm-project/pull/106842.diff
1 Files Affected:
- (modified) llvm/lib/Transforms/Scalar/GVNHoist.cpp (+253)
``````````diff
diff --git a/llvm/lib/Transforms/Scalar/GVNHoist.cpp b/llvm/lib/Transforms/Scalar/GVNHoist.cpp
index b5333c532280ca..13d67e9e212039 100644
--- a/llvm/lib/Transforms/Scalar/GVNHoist.cpp
+++ b/llvm/lib/Transforms/Scalar/GVNHoist.cpp
@@ -57,6 +57,7 @@
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Use.h"
@@ -74,10 +75,15 @@
#include <memory>
#include <utility>
#include <vector>
+#include <map>
+#include <set>
using namespace llvm;
#define DEBUG_TYPE "gvn-hoist"
+using std::vector;
+using std::map;
+using std::set;
STATISTIC(NumHoisted, "Number of instructions hoisted");
STATISTIC(NumRemoved, "Number of instructions removed");
@@ -1139,6 +1145,252 @@ std::pair<unsigned, unsigned> GVNHoist::hoist(HoistingPointList &HPL) {
return {NI, NL + NC + NS};
}
+struct SwitchTreeNode {
+ /*
+ One node of the two-level tree that BuildSwitchtree creates for a switch
+ instruction: a root plus one child per switch case.
+
+ ValueInstList maps a value number (VN) to an instruction:
+ - root node: a single entry whose key is the dummy value number -10000 and
+   whose value is the switch instruction itself;
+ - child node: one <VN, instruction> entry per value-numbered instruction
+   found in that case's successor block (<V1,I1>, <V2,I2>, <V3,I3>, ...).
+   Entries are added with std::map::insert, so the first instruction recorded
+   for a given VN is the one that is kept.
+ */
+ std::map<int, Instruction *> ValueInstList;
+ /*
+ Only the root node has children: one node per switch case, pushed in case
+ iteration order. Nodes are allocated with new by BuildSwitchtree and this
+ struct has no destructor, so whoever holds the root has to free the children
+ (and the root) explicitly.
+ */
+ std::vector<SwitchTreeNode *> children;
+};
+
+// Builds the two-level SwitchTreeNode tree for the switch \p SwI.
+//
+// \p ScalarsMap is a VNtoInsns table: [VN -> list(instrs)]. Only the first
+// component of each key is used as the value number here.
+//
+// The root holds a single <dummy VN, switch instruction> entry. For every case
+// of the switch (the default destination is not visited) one child is
+// appended, holding a <VN, instruction> entry for each instruction of
+// \p ScalarsMap that lives in the case's successor block. When several
+// instructions of one block share a VN, the first one inserted is kept.
+//
+// Returns the root. Every node is allocated with new; the caller owns the tree
+// and must delete the children and the root.
+SwitchTreeNode *BuildSwitchtree(SwitchInst *SwI, const VNtoInsns &ScalarsMap) {
+  // Key of the root's only entry. performSwitchHoist reads the switch back via
+  // ValueInstList.begin(), so the root must not receive any other entry.
+  constexpr int RootDummyVN = -10000;
+
+  // A SwitchInst is an Instruction; a plain upcast needs no dyn_cast.
+  Instruction *SwitchAsInst = SwI;
+  SwitchTreeNode *SwRoot = new SwitchTreeNode();
+  SwRoot->ValueInstList.insert({RootDummyVN, SwitchAsInst});
+
+  for (auto &Case : SwI->cases()) {
+    BasicBlock *CaseBB = Case.getCaseSuccessor();
+    SwitchTreeNode *Child = new SwitchTreeNode();
+    for (Instruction &Inst : *CaseBB) {
+      // Iterate by reference: copying every table entry (including its
+      // instruction list) for each instruction of each case block is wasted
+      // work.
+      for (const auto &Entry : ScalarsMap) {
+        const unsigned VN = Entry.first.first;
+        for (Instruction *VnInst : Entry.second)
+          // Inst belongs to CaseBB, so comparing against CaseBB is the same
+          // as comparing the two parents.
+          if (Inst.isIdenticalTo(VnInst) && VnInst->getParent() == CaseBB)
+            Child->ValueInstList.insert(
+                std::make_pair(static_cast<int>(VN), VnInst));
+      }
+    }
+    SwRoot->children.push_back(Child);
+  }
+  return SwRoot;
+}
+
+// Scores the candidate set \p CandVNSet against the switch tree \p SwitchRoot.
+// The score is (number of VNs in the set) * (number of case nodes whose
+// ValueInstList contains every VN of the set). An empty set therefore always
+// scores 0.
+unsigned computeLocalSetDensity(SwitchTreeNode *SwitchRoot,
+                                vector<unsigned> &CandVNSet) {
+  unsigned BlocksCoveringSet = 0;
+  for (SwitchTreeNode *CaseNode : SwitchRoot->children) {
+    // Stop at the first candidate VN this case node does not provide.
+    bool MissingVN = false;
+    for (size_t Idx = 0; Idx != CandVNSet.size() && !MissingVN; ++Idx)
+      MissingVN = CaseNode->ValueInstList.count(CandVNSet[Idx]) == 0;
+    if (!MissingVN)
+      ++BlocksCoveringSet;
+  }
+  return CandVNSet.size() * BlocksCoveringSet;
+}
+
+// Picks the set of value numbers to hoist for the switch tree \p SwitchRoot.
+//
+// Brute force for now: every subset of the distinct VNs found in the case
+// nodes is scored with computeLocalSetDensity, i.e. density = (number of VNs
+// in the subset) * (number of case blocks providing all of them). The first
+// subset reaching the maximum density wins. Returns an empty vector when there
+// are more than 10 distinct VNs or when no subset has a non-zero density.
+vector<unsigned> ChooseHoistableCandidates(SwitchTreeNode *SwitchRoot) {
+  assert(SwitchRoot != nullptr && "Null Switch Root found");
+
+  // Distinct VNs across all case nodes, in first-seen order.
+  vector<unsigned> DistinctVNs;
+  for (SwitchTreeNode *CaseNode : SwitchRoot->children)
+    for (const auto &Entry : CaseNode->ValueInstList) {
+      auto Pos = std::find(DistinctVNs.begin(), DistinctVNs.end(), Entry.first);
+      if (Pos == DistinctVNs.end())
+        DistinctVNs.push_back(Entry.first);
+    }
+
+  // Cap on the exhaustive search: 2^10 subsets at most.
+  const unsigned MaxVNsForBruteForce = 10;
+  const unsigned NumVNs = DistinctVNs.size();
+  if (NumVNs > MaxVNsForBruteForce)
+    return {};
+
+  const unsigned NumSubsets = 1u << NumVNs;
+  vector<unsigned> Best;
+  unsigned BestDensity = 0;
+  for (unsigned Mask = 0; Mask != NumSubsets; ++Mask) {
+    // Materialise the subset encoded by Mask; the highest bit selects
+    // DistinctVNs[0], the lowest bit selects the last VN.
+    vector<unsigned> Subset;
+    unsigned Bit = NumSubsets >> 1;
+    for (unsigned Idx = 0; Idx != NumVNs; ++Idx, Bit >>= 1)
+      if (Mask & Bit)
+        Subset.push_back(DistinctVNs[Idx]);
+
+    const unsigned Density = computeLocalSetDensity(SwitchRoot, Subset);
+    if (Density > BestDensity) {
+      BestDensity = Density;
+      Best = Subset;
+    }
+  }
+  return Best;
+}
+
+// Re-keys \p ScalarsMap by the first component of its pair key only, i.e.
+// returns [VN -> list(instrs)].
+//
+// Several entries of \p ScalarsMap may share the same first component; their
+// instruction lists are merged rather than letting the entry visited last
+// overwrite the earlier ones.
+DenseMap<unsigned, SmallVector<Instruction *, 4>>
+getPureVNToInstrsMap(const VNtoInsns &ScalarsMap) {
+  DenseMap<unsigned, SmallVector<Instruction *, 4>> VNToInstrsMap;
+  // Iterate by reference so each entry's instruction list is not copied twice.
+  for (const auto &Entry : ScalarsMap) {
+    SmallVector<Instruction *, 4> &Insts = VNToInstrsMap[Entry.first.first];
+    Insts.append(Entry.second.begin(), Entry.second.end());
+  }
+  return VNToInstrsMap;
+}
+
+// Returns the set of value numbers (first component of the \p ScalarsMap key)
+// that have at least one instruction located in \p BB.
+set<unsigned> getVNsInBasicBlock(BasicBlock *BB, const VNtoInsns &ScalarsMap) {
+  set<unsigned> VNsInBB;
+  // Scan the table in place. This function is called once per switch case, so
+  // building a re-keyed copy of the whole table just to iterate it is avoided.
+  for (const auto &Entry : ScalarsMap) {
+    for (Instruction *Inst : Entry.second) {
+      if (Inst->getParent() == BB) {
+        VNsInBB.insert(Entry.first.first);
+        // One instruction in BB is enough to record this VN.
+        break;
+      }
+    }
+  }
+  return VNsInBB;
+}
+
+// For each instruction in \p InstrsToHoist, lists the instructions of \p TBB
+// that isIdenticalTo() reports as identical to it, in block order. Hoisted
+// instructions without any match in \p TBB get no entry in the result.
+map<Instruction *, vector<Instruction *>>
+buildHoistToOldInstrsMapForBB(BasicBlock &TBB,
+                              vector<Instruction *> InstrsToHoist) {
+  map<Instruction *, vector<Instruction *>> Result;
+  for (Instruction *Hoisted : InstrsToHoist) {
+    // Collect this hoisted instruction's matches first, so that an entry is
+    // only created when there is something to store.
+    vector<Instruction *> Matches;
+    for (Instruction &Candidate : TBB)
+      if (Hoisted->isIdenticalTo(&Candidate))
+        Matches.push_back(&Candidate);
+    if (Matches.empty())
+      continue;
+    vector<Instruction *> &Slot = Result[Hoisted];
+    Slot.insert(Slot.end(), Matches.begin(), Matches.end());
+  }
+  return Result;
+}
+
+// Hoists one instruction per candidate value number out of the case blocks of
+// a switch into a newly created block named "hoist.block".
+//
+// \p SwitchRoot  root of the tree built by BuildSwitchtree; its first
+//                ValueInstList entry must be the switch instruction.
+// \p ScalarsMap  [VN -> list(instrs)] table; only the first key component is
+//                used as the value number.
+// \p HoistCand   value numbers chosen for hoisting.
+//
+// Steps: (1) collect the cases whose successor block contains every candidate
+// VN ("targets"); (2) clone one instruction per candidate VN into the hoist
+// block; (3) end the hoist block with a switch on the same condition that
+// dispatches to the targets; (4) redirect the original switch's target cases
+// to the hoist block; (5) in each target, add a PHI per hoisted instruction
+// and replace/erase the identical original instructions.
+//
+// Returns false without touching the IR when no case block qualifies, true
+// once the transform has been applied.
+bool performSwitchHoist(SwitchTreeNode *SwitchRoot, const VNtoInsns &ScalarsMap,
+                        const vector<unsigned> HoistCand) {
+  Instruction *SI = SwitchRoot->ValueInstList.begin()->second;
+  // Check before the first dereference of SI.
+  assert(SI != nullptr && "Null Switch Root Found");
+  // The root entry is always a switch; cast<> asserts that instead of handing
+  // back a null pointer that would be dereferenced below.
+  SwitchInst *SwInst = cast<SwitchInst>(SI);
+  Function &F = *SI->getFunction();
+
+  // Acquire the target BBs by checking whether a case block has all the
+  // candidate VNs.
+  set<std::pair<BasicBlock *, ConstantInt *>> Targets;
+  for (auto &Case : SwInst->cases()) {
+    BasicBlock *CaseBB = Case.getCaseSuccessor();
+    const set<unsigned> VNsForBB = getVNsInBasicBlock(CaseBB, ScalarsMap);
+    bool AllVNsExist = true;
+    for (unsigned VN : HoistCand) {
+      if (VNsForBB.count(VN) == 0) {
+        AllVNsExist = false;
+        break;
+      }
+    }
+    if (AllVNsExist)
+      Targets.insert({CaseBB, Case.getCaseValue()});
+  }
+
+  // Nothing to hoist into: do not create an unreachable hoist block.
+  if (Targets.empty())
+    return false;
+
+  // Collect the values to be hoisted: the first instruction seen for each
+  // candidate VN.
+  // NOTE(review): HoistValues is ordered by pointer value, so the order of the
+  // cloned instructions in the hoist block is not tied to program order --
+  // confirm this is acceptable when hoisted values depend on each other.
+  set<Value *> HoistValues;
+  set<unsigned> Visited;
+  for (const auto &Entry : ScalarsMap) {
+    const unsigned VN = Entry.first.first;
+    // Only VNs that are hoist candidates contribute instructions.
+    if (std::find(HoistCand.begin(), HoistCand.end(), VN) == HoistCand.end())
+      continue;
+    for (Instruction *Inst : Entry.second) {
+      if (HoistValues.find(Inst) == HoistValues.end() &&
+          Visited.find(VN) == Visited.end()) {
+        HoistValues.insert(Inst);
+        Visited.insert(VN);
+      }
+    }
+  }
+
+  // Hoist the values into the new hoist block.
+  LLVMContext &Context = F.getContext();
+  BasicBlock *HoistBlock = BasicBlock::Create(Context, "hoist.block", &F);
+  vector<Instruction *> InstrsToHoist;
+  for (Value *V : HoistValues)
+    // Everything in HoistValues was inserted as an Instruction above.
+    InstrsToHoist.push_back(cast<Instruction>(V)->clone());
+
+  IRBuilder<> Builder(HoistBlock);
+  for (Instruction *I : InstrsToHoist)
+    Builder.Insert(I);
+
+  // Create the switch from the hoist block to the targets.
+  // NOTE(review): the new switch reuses the original default destination,
+  // which gives that block an extra predecessor; PHIs there are not updated
+  // here -- confirm.
+  SwitchInst *NewSwitch = Builder.CreateSwitch(
+      SwInst->getCondition(), SwInst->getDefaultDest(), Targets.size());
+
+  for (const auto &Target : Targets)
+    NewSwitch->addCase(Target.second, Target.first);
+
+  // Update the parent switch so the cases under consideration branch to the
+  // hoist block. setSuccessor takes a successor index, not a case value, so
+  // look the case up and use its successor index.
+  for (const auto &Target : Targets)
+    SwInst->setSuccessor(
+        SwInst->findCaseValue(Target.second)->getSuccessorIndex(), HoistBlock);
+
+  // For each target BB, create an incoming PHI for each hoisted value, then
+  // replace the uses of the identical original instructions with that PHI and
+  // erase the originals.
+  // NOTE(review): each PHI gets a single incoming value (from the hoist
+  // block); a target that keeps other predecessors would need more incoming
+  // entries -- confirm. The erased instructions also remain listed in
+  // ScalarsMap.
+  for (const auto &Target : Targets) {
+    BasicBlock *Tgt = Target.first;
+    map<Instruction *, vector<Instruction *>> HoistInstrToOldInstrMap =
+        buildHoistToOldInstrsMapForBB(*Tgt, InstrsToHoist);
+    for (Instruction *I : InstrsToHoist) {
+      PHINode *PHI = PHINode::Create(I->getType(), 1, "", &Tgt->front());
+      PHI->addIncoming(I, HoistBlock);
+      vector<Instruction *> &ToReplOld = HoistInstrToOldInstrMap[I];
+      for (Instruction *OI : ToReplOld) {
+        OI->replaceAllUsesWith(PHI);
+        assert(OI->getNumUses() == 0 &&
+               "Attempting to delete instr with uses.");
+        OI->eraseFromParent();
+      }
+    }
+  }
+  assert(NewSwitch != nullptr &&
+         "new switch cannot be null after the transform");
+  return true;
+}
+
+// Runs the switch-hoisting transform on every switch instruction of \p F.
+//
+// \p ScalarsMap is the [VN -> list(instrs)] table used to find value-numbered
+// instructions inside the case blocks. The switches are collected up front
+// because performSwitchHoist adds blocks and switches to \p F.
+void doSwitchHoist(Function &F, const VNtoInsns &ScalarsMap) {
+  vector<SwitchInst *> SwitchInstrs;
+  for (BasicBlock &BB : F) {
+    for (Instruction &I : BB) {
+      if (SwitchInst *SwI = dyn_cast<SwitchInst>(&I))
+        SwitchInstrs.push_back(SwI);
+    }
+  }
+
+  // NOTE(review): performSwitchHoist erases instructions that ScalarsMap still
+  // lists, and later iterations read ScalarsMap again -- confirm that no stale
+  // instruction can be reached.
+  for (SwitchInst *SwI : SwitchInstrs) {
+    SwitchTreeNode *SwitchRoot = BuildSwitchtree(SwI, ScalarsMap);
+    vector<unsigned> HoistCand = ChooseHoistableCandidates(SwitchRoot);
+    if (!HoistCand.empty())
+      performSwitchHoist(SwitchRoot, ScalarsMap, HoistCand);
+
+    // The tree is allocated with new and nothing retains it past this point;
+    // free it here. Only the root has children, so one level is enough.
+    for (SwitchTreeNode *Child : SwitchRoot->children)
+      delete Child;
+    delete SwitchRoot;
+  }
+}
+
std::pair<unsigned, unsigned> GVNHoist::hoistExpressions(Function &F) {
InsnInfo II;
LoadInfo LI;
@@ -1196,6 +1448,7 @@ std::pair<unsigned, unsigned> GVNHoist::hoistExpressions(Function &F) {
computeInsertionPoints(CI.getScalarVNTable(), HPL, InsKind::Scalar);
computeInsertionPoints(CI.getLoadVNTable(), HPL, InsKind::Load);
computeInsertionPoints(CI.getStoreVNTable(), HPL, InsKind::Store);
+ doSwitchHoist(F,II.getVNTable());
return hoist(HPL);
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/106842
More information about the llvm-commits
mailing list