[llvm] [SPIR-V] Add SPIR-V structurizer (PR #97606)
Michal Paszkowski via llvm-commits
llvm-commits at lists.llvm.org
Sun Jul 7 16:35:23 PDT 2024
================
@@ -0,0 +1,649 @@
+//===-- SPIRVStructurizer.cpp -----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass adds the required OpSelectionMerge/OpLoopMerge instructions to
+// generate valid SPIR-V.
+// This pass trims convergence intrinsics as those were only useful when
+// modifying the CFG during IR passes.
+//
+//===----------------------------------------------------------------------===//
+
+#include <stack>
+
+#include "Analysis/SPIRVConvergenceRegionAnalysis.h"
+#include "SPIRV.h"
+#include "SPIRVSubtarget.h"
+#include "SPIRVTargetMachine.h"
+#include "SPIRVUtils.h"
+#include "llvm/ADT/BreadthFirstIterator.h"
+#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
+#include "llvm/CodeGen/IntrinsicLowering.h"
+#include "llvm/CodeGen/MachineLoopInfo.h"
+#include "llvm/CodeGen/MachinePostDominators.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
+
+using namespace llvm;
+using namespace SPIRV;
+
+namespace llvm {
+void initializeSPIRVStructurizerPass(PassRegistry &);
+}
+
+namespace {
+
+// Searches the convergence region tree rooted at `Node` for the region whose
+// entry block is the IR block backing `MBB`. `Node` itself is checked first,
+// then each child subtree in order. Returns the matching region, or nullptr
+// when no region of the tree starts at `MBB`.
+const ConvergenceRegion *getRegionForHeader(const ConvergenceRegion *Node,
+                                            MachineBasicBlock *MBB) {
+  if (Node->Entry == MBB->getBasicBlock())
+    return Node;
+
+  for (auto *Child : Node->Children)
+    if (const auto *Found = getRegionForHeader(Child, MBB))
+      return Found;
+
+  return nullptr;
+}
+
+// Returns the first MachineBasicBlock of `MF` (in layout order) whose IR block
+// is `BB`, or nullptr when no machine block maps to `BB`.
+MachineBasicBlock *getMachineBlockFor(MachineFunction &MF, BasicBlock *BB) {
+  for (MachineBasicBlock &Candidate : MF) {
+    if (Candidate.getBasicBlock() != BB)
+      continue;
+    return &Candidate;
+  }
+  return nullptr;
+}
+
+// Gathers all the distinct successors of |BB|.
+// Only branch, switch and return terminators are supported: this function
+// asserts if the terminator is of any other kind (and returns an empty set
+// when assertions are disabled).
+std::unordered_set<BasicBlock *> gatherSuccessors(BasicBlock *BB) {
+  std::unordered_set<BasicBlock *> Output;
+  auto *T = BB->getTerminator();
+
+  if (auto *BI = dyn_cast<BranchInst>(T)) {
+    Output.insert(BI->getSuccessor(0));
+    if (BI->isConditional())
+      Output.insert(BI->getSuccessor(1));
+    return Output;
+  }
+
+  if (auto *SI = dyn_cast<SwitchInst>(T)) {
+    Output.insert(SI->getDefaultDest());
+    for (auto &Case : SI->cases())
+      Output.insert(Case.getCaseSuccessor());
+    return Output;
+  }
+
+  // A return has no successor. Only the kind of the terminator matters here,
+  // so test it with isa<> instead of binding an unused dyn_cast<> result.
+  if (isa<ReturnInst>(T))
+    return Output;
+
+  assert(false && "Unhandled terminator type.");
+  return Output;
+}
+
+// Returns the MachineBasicBlock the convergence region `CR` exits to: the one
+// block outside of `CR` targeted by the terminators of its exit blocks.
+// Returns nullptr if the region has no such target. Asserts that there is at
+// most one. MF must be the function CR belongs to.
+MachineBasicBlock *getExitFor(MachineFunction &MF,
+                              const ConvergenceRegion *CR) {
+  // Successors of the exit blocks which are not part of the region.
+  std::unordered_set<BasicBlock *> Outside;
+  for (BasicBlock *ExitBlock : CR->Exits)
+    for (BasicBlock *Succ : gatherSuccessors(ExitBlock))
+      if (CR->Blocks.count(Succ) == 0)
+        Outside.insert(Succ);
+
+  assert(Outside.size() <= 1);
+  if (Outside.empty())
+    return nullptr;
+
+  return getMachineBlockFor(MF, *Outside.begin());
+}
+
+// Returns true if |I| is an OpLoopMerge or an OpSelectionMerge instruction.
+bool isMergeInstruction(const MachineInstr &I) {
+  switch (I.getOpcode()) {
+  case SPIRV::OpLoopMerge:
+  case SPIRV::OpSelectionMerge:
+    return true;
+  default:
+    return false;
+  }
+}
+
+// Scans |MBB| from its first instruction and returns the first
+// OpLoopMerge/OpSelectionMerge instruction encountered, or nullptr if |MBB|
+// contains neither.
+MachineInstr *getMergeInstruction(MachineBasicBlock &MBB) {
+  for (auto It = MBB.begin(), End = MBB.end(); It != End; ++It)
+    if (isMergeInstruction(*It))
+      return &*It;
+  return nullptr;
+}
+
+// Scans |MBB| from its first instruction and returns the first OpLoopMerge
+// instruction encountered, or nullptr if |MBB| contains none.
+MachineInstr *getLoopMergeInstruction(MachineBasicBlock &MBB) {
+  for (MachineInstr &MI : MBB) {
+    if (MI.getOpcode() != SPIRV::OpLoopMerge)
+      continue;
+    return &MI;
+  }
+  return nullptr;
+}
+
+// Scans |MBB| from its first instruction and returns the first
+// OpSelectionMerge instruction encountered, or nullptr if |MBB| contains none.
+MachineInstr *getSelectionMergeInstruction(MachineBasicBlock &MBB) {
+  for (MachineInstr &MI : MBB) {
+    if (MI.getOpcode() != SPIRV::OpSelectionMerge)
+      continue;
+    return &MI;
+  }
+  return nullptr;
+}
+
+// Walks the CFG of |MF| starting from its first block, using an explicit
+// stack of pending blocks, and calls |op| exactly once on every block
+// reachable from it. A block is marked as seen when it is pushed, so it is
+// never queued twice.
+void visit(MachineFunction &MF, std::function<void(MachineBasicBlock *)> op) {
+  std::stack<MachineBasicBlock *> Pending;
+  SmallPtrSet<MachineBasicBlock *, 8> Seen;
+
+  MachineBasicBlock *Entry = &*MF.begin();
+  Pending.push(Entry);
+  Seen.insert(Entry);
+
+  while (!Pending.empty()) {
+    MachineBasicBlock *Current = Pending.top();
+    Pending.pop();
+
+    op(Current);
+
+    // insert() reports whether the block was newly added: queue it only then.
+    for (MachineBasicBlock *Succ : Current->successors())
+      if (Seen.insert(Succ).second)
+        Pending.push(Succ);
+  }
+}
+
+// Returns the set of basic blocks in |MF| containing at least one
+// OpSelectionMerge/OpLoopMerge instruction.
+SmallPtrSet<MachineBasicBlock *, 8> getHeaderBlocks(MachineFunction &MF) {
+  SmallPtrSet<MachineBasicBlock *, 8> Headers;
+  for (MachineBasicBlock &MBB : MF)
+    if (getMergeInstruction(MBB))
+      Headers.insert(&MBB);
+  return Headers;
+}
+
+// Returns the set of basic blocks used as merge target (operand 0) by the
+// first OpSelectionMerge/OpLoopMerge instruction of any block in |MF|.
+SmallPtrSet<MachineBasicBlock *, 8> getMergeBlocks(MachineFunction &MF) {
+  SmallPtrSet<MachineBasicBlock *, 8> Targets;
+  for (MachineBasicBlock &MBB : MF) {
+    MachineInstr *Merge = getMergeInstruction(MBB);
+    if (!Merge)
+      continue;
+    Targets.insert(Merge->getOperand(0).getMBB());
+  }
+  return Targets;
+}
+
+// Returns all basic blocks in |MF| referenced as continue target (operand 1)
+// by the first OpLoopMerge instruction of a block.
+SmallPtrSet<MachineBasicBlock *, 8> getContinueBlocks(MachineFunction &MF) {
+  SmallPtrSet<MachineBasicBlock *, 8> Output;
+  for (MachineBasicBlock &MBB : MF) {
+    // Look for an OpLoopMerge specifically. Asking for the first merge
+    // instruction of any kind would miss the continue target of a block in
+    // which an OpSelectionMerge comes before the OpLoopMerge.
+    if (MachineInstr *MI = getLoopMergeInstruction(MBB))
+      Output.insert(MI->getOperand(1).getMBB());
+  }
+  return Output;
+}
+
+// Returns the nearest block post-dominating every block in |range| according
+// to |MPDT|, or nullptr if the blocks have no common post-dominator block.
+// |range| must not be empty.
+MachineBasicBlock *findNearestCommonDominator(
+    const iterator_range<std::vector<MachineBasicBlock *>::iterator> &range,
+    MachinePostDominatorTree &MPDT) {
+  assert(!range.empty());
+  MachineBasicBlock *Dom = *range.begin();
+  for (MachineBasicBlock *Item : range) {
+    Dom = MPDT.findNearestCommonDominator(Dom, Item);
+    // No common post-dominator block so far: adding more blocks cannot create
+    // one, and a null block must not be fed back into the tree query.
+    if (!Dom)
+      return nullptr;
+  }
+  return Dom;
+}
+
+// Looks up the first merge instruction of |MBB| and stores it in |MI|
+// (nullptr if there is none).
+// |Merge| receives its merge target (operand 0), nullptr if no merge
+// instruction was found.
+// |Continue| receives its continue target (operand 1) when the instruction is
+// an OpLoopMerge, nullptr otherwise.
+// Returns true if a merge instruction was found, false otherwise.
+bool getMergeInstructionTargets(MachineBasicBlock *MBB, MachineInstr **MI,
+                                MachineBasicBlock **Merge,
+                                MachineBasicBlock **Continue) {
+  *Merge = nullptr;
+  *Continue = nullptr;
+
+  *MI = getMergeInstruction(*MBB);
+  if (!*MI)
+    return false;
+
+  MachineInstr &Found = **MI;
+  *Merge = Found.getOperand(0).getMBB();
+  if (Found.getOpcode() == SPIRV::OpLoopMerge)
+    *Continue = Found.getOperand(1).getMBB();
+  else
+    *Continue = nullptr;
+  return true;
+}
+
+} // anonymous namespace
+
+class SPIRVStructurizer : public MachineFunctionPass {
+public:
+ static char ID;
+
+ // Builds the pass and runs its initialization hook against the global
+ // PassRegistry.
+ SPIRVStructurizer() : MachineFunctionPass(ID) {
+   PassRegistry &Registry = *PassRegistry::getPassRegistry();
+   initializeSPIRVStructurizerPass(Registry);
+ }
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ MachineFunctionPass::getAnalysisUsage(AU); // Let the base class declare its own usage first.
+ AU.addRequired<MachineLoopInfo>(); // Queried by addMergeForLoops() for loop headers and latches.
+ AU.addRequired<DominatorTreeWrapperPass>(); // NOTE(review): not queried in the visible code; confirm it is needed.
+ AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>(); // Provides the convergence region tree used by addMergeForLoops().
+ }
+
+ // Creates a new MachineBasicBlock containing a single OpReturn, placed right
+ // after |Predecessor| in the function layout and sharing its IR block.
+ // The new block is registered as a successor of |Predecessor|, but no branch
+ // instruction is added to |Predecessor|: this is left to the caller.
+ // Returns the new block.
+ MachineBasicBlock *createReturnBlock(MachineFunction &MF,
+                                      MachineBasicBlock &Predecessor) {
+   MachineBasicBlock *ReturnBB =
+       MF.CreateMachineBasicBlock(Predecessor.getBasicBlock());
+   MF.push_back(ReturnBB);
+   ReturnBB->moveAfter(&Predecessor);
+
+   // Only the CFG edge is recorded here, no branch instruction is emitted.
+   Predecessor.addSuccessorWithoutProb(ReturnBB);
+
+   MachineIRBuilder Builder(MF);
+   Builder.setInsertPt(*ReturnBB, ReturnBB->end());
+   Builder.buildInstr(SPIRV::OpReturn);
+   return ReturnBB;
+ }
+
+ // Replaces switches with a single target with an unconditional branch.
+ // Only the last instruction of each block is inspected: it is replaced when
+ // it is an spv_switch intrinsic with at most 3 operands, operand 2 being the
+ // single target block.
+ // Returns true if at least one switch was replaced.
+ bool replaceEmptySwitchWithBranch(MachineFunction &MF) {
+   bool Modified = false;
+   for (MachineBasicBlock &MBB : MF) {
+     // An empty block has no last instruction: rbegin() must not be
+     // dereferenced.
+     if (MBB.empty())
+       continue;
+
+     MachineInstr *I = &*MBB.rbegin();
+     GIntrinsic *II = dyn_cast<GIntrinsic>(I);
+     if (!II || II->getIntrinsicID() != Intrinsic::spv_switch ||
+         II->getNumOperands() > 3)
+       continue;
+
+     Modified = true;
+     assert(II->getOperand(2).isMBB());
+     MachineBasicBlock *Target = II->getOperand(2).getMBB();
+
+     // Append the branch first, then drop the switch it replaces.
+     MachineIRBuilder MIRBuilder(MF);
+     MIRBuilder.setInsertPt(MBB, MBB.end());
+     MIRBuilder.buildBr(*Target);
+     MBB.erase(I);
+   }
+
+   return Modified;
+ }
+
+ // Traverses each loop header, and adds an OpLoopMerge instruction to it when
+ // it is also the entry of a convergence region.
+ // The continue target is the loop latch (the loop must have a single one).
+ // The merge target is the only block the convergence region exits to. When
+ // there is none (infinite loop), a new block ending with an OpReturn is
+ // created and used as merge target.
+ // Returns true if at least one OpLoopMerge was added.
+ bool addMergeForLoops(MachineFunction &MF) {
+   auto &TII = *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo();
+   auto &TRI = *MF.getSubtarget<SPIRVSubtarget>().getRegisterInfo();
+   auto &RBI = *MF.getSubtarget<SPIRVSubtarget>().getRegBankInfo();
+
+   const auto &MLI = getAnalysis<MachineLoopInfo>();
+   const auto *TopLevelRegion =
+       getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
+           .getRegionInfo()
+           .getTopLevelRegion();
+
+   bool Modified = false;
+   for (auto &MBB : MF) {
+     // Not a loop header. Ignoring for now.
+     if (!MLI.isLoopHeader(&MBB))
+       continue;
+     auto *L = MLI.getLoopFor(&MBB);
+
+     // This loop header is not the entrance of a convergence region. Ignoring
+     // this block.
+     auto *CR = getRegionForHeader(TopLevelRegion, &MBB);
+     if (CR == nullptr)
+       continue;
+
+     auto *Merge = getExitFor(MF, CR);
+     // This is a special case:
+     // We are indeed in a loop, but there are no exits (infinite loop).
+     // No existing block can serve as merge target, so a dedicated return
+     // block is created for the OpLoopMerge to reference.
+     if (Merge == nullptr)
+       Merge = createReturnBlock(MF, MBB);
+
+     // getLoopLatch() yields a null latch when the loop does not have exactly
+     // one: fail explicitly instead of emitting a null continue target.
+     auto *Continue = L->getLoopLatch();
+     assert(Continue && "Expected a loop with a single latch.");
+
+     // Conditional branch are built using a fallthrough if false + BR.
+     // So the last instruction is not always the first branch.
+     // When the block has no terminator, insert before its last instruction,
+     // as addMergeForConditionalBranches() does.
+     auto *I = MBB.getFirstTerminator() == MBB.end()
+                   ? &*MBB.rbegin()
+                   : &*MBB.getFirstTerminator();
+     // NOTE(review): the control mask operand is spelled with the
+     // SelectionControl enum although this is an OpLoopMerge - confirm.
+     BuildMI(MBB, I, I->getDebugLoc(), TII.get(SPIRV::OpLoopMerge))
+         .addMBB(Merge)
+         .addMBB(Continue)
+         .addImm(SPIRV::SelectionControl::None)
+         .constrainAllUses(TII, TRI, RBI);
+     Modified = true;
+   }
+
+   return Modified;
+ }
+
+ // Adds an OpSelectionMerge to each block with 2 or more successors, unless
+ // the block already has one, or at most one of its successors is neither a
+ // merge target nor a continue target of an existing merge instruction.
+ // The merge target is the nearest common post-dominator of the successors,
+ // or the first successor when there is none.
+ // Returns true if at least one OpSelectionMerge was added.
+ bool addMergeForConditionalBranches(MachineFunction &MF) {
+   MachinePostDominatorTree MPDT(MF);
+
+   auto &TII = *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo();
+   auto &TRI = *MF.getSubtarget<SPIRVSubtarget>().getRegisterInfo();
+   auto &RBI = *MF.getSubtarget<SPIRVSubtarget>().getRegBankInfo();
+
+   // Computed once, before the loop below adds any OpSelectionMerge.
+   auto MergeBlocks = getMergeBlocks(MF);
+   auto ContinueBlocks = getContinueBlocks(MF);
+
+   bool Modified = false;
+   for (auto &MBB : MF) {
+     if (MBB.succ_size() <= 1)
+       continue;
+
+     // Block already has an OpSelectionMerge instruction. Ignoring.
+     if (getSelectionMergeInstruction(MBB))
+       continue;
+
+     // Count the successors which are not already the merge or continue
+     // target of a construct.
+     size_t NonStructurizedTargets = 0;
+     for (MachineBasicBlock *Successor : MBB.successors()) {
+       if (!MergeBlocks.contains(Successor) &&
+           !ContinueBlocks.contains(Successor))
+         NonStructurizedTargets += 1;
+     }
+
+     if (NonStructurizedTargets <= 1)
+       continue;
+
+     MachineBasicBlock *Merge =
+         findNearestCommonDominator(MBB.successors(), MPDT);
+     if (!Merge) {
+       // TODO: we should check which block is not another construct merge
+       // block, and select this one. For now, tests pass with this strategy,
+       // but once we find a test case, we should fix that.
+       Merge = *MBB.succ_begin();
+     }
+
+     assert(Merge);
+     // Insert before the first terminator, or before the last instruction
+     // when the block has no terminator.
+     auto *II = MBB.getFirstTerminator() == MBB.end()
+                    ? &*MBB.rbegin()
+                    : &*MBB.getFirstTerminator();
+     BuildMI(MBB, II, II->getDebugLoc(), TII.get(SPIRV::OpSelectionMerge))
+         .addMBB(Merge)
+         .addImm(SPIRV::SelectionControl::None)
+         .constrainAllUses(TII, TRI, RBI);
+     // Report the change: an instruction was added to the function.
+     Modified = true;
+   }
+
+   return Modified;
+ }
+
+ // Splits |Block| at its first OpLoopMerge/OpSelectionMerge instruction using
+ // MachineBasicBlock::splitAt, then appends to |Block| an unconditional
+ // branch to the block created by the split. Returns that new block.
+ // |Block| must contain a merge instruction.
+ MachineBasicBlock *splitHeaderBlock(MachineFunction &MF,
+                                     MachineBasicBlock &Block) {
+   // Advance to the first merge instruction.
+   auto MergeIt = Block.begin();
+   for (; !isMergeInstruction(*MergeIt); MergeIt++)
+     ;
+
+   MachineBasicBlock *Tail = Block.splitAt(*MergeIt);
+
+   MachineIRBuilder Builder(MF);
+   Builder.setInsertPt(Block, Block.end());
+   Builder.buildBr(*Tail);
+
+   return Tail;
+ }
+
+ // Splits every basic block of |MF| containing both an OpSelectionMerge and
+ // an OpLoopMerge instruction, using splitHeaderBlock().
+ // Returns true if at least one block was split.
+ bool splitBlocksWithMultipleHeaders(MachineFunction &MF) {
+   bool Modified = false;
+   for (MachineBasicBlock &MBB : MF) {
+     const bool HasSelectionMerge =
+         getSelectionMergeInstruction(MBB) != nullptr;
+     const bool HasLoopMerge = getLoopMergeInstruction(MBB) != nullptr;
+     if (HasSelectionMerge && HasLoopMerge) {
+       splitHeaderBlock(MF, MBB);
+       Modified = true;
+     }
+   }
+   return Modified;
+ }
+
+ // Inserts a new block right before |OldMerge| which branches unconditionally
+ // to |OldMerge|, and returns it. Each predecessor of |OldMerge| dominated by
+ // |HeaderBlock| is redirected (successor list and MBB operands) to the new
+ // block, and the PHI uses in |OldMerge| are updated accordingly. The first
+ // merge instruction of |HeaderBlock| is retargeted to the new block.
+ MachineBasicBlock *splitMergeBlock(MachineDominatorTree &MDT,
+ MachineFunction &MF,
+ MachineBasicBlock &OldMerge,
+ MachineBasicBlock &HeaderBlock) {
+
+ std::vector<MachineBasicBlock *> toUpdate; // Snapshot of the predecessors, taken before the edges are rewritten below.
+ for (MachineBasicBlock *Predecessor : OldMerge.predecessors())
+ toUpdate.push_back(Predecessor);
+
+ MachineBasicBlock *NewMerge =
+ MF.CreateMachineBasicBlock(OldMerge.getBasicBlock());
+ MF.push_back(NewMerge);
+ NewMerge->moveBefore(&OldMerge);
+ NewMerge->addSuccessorWithoutProb(&OldMerge);
+ MachineIRBuilder MIRBuilder(MF);
+ MIRBuilder.setInsertPt(*NewMerge, NewMerge->end());
+ MIRBuilder.buildBr(OldMerge);
+
+ for (MachineBasicBlock *Predecessor : toUpdate) {
+ if (!MDT.dominates(&HeaderBlock, Predecessor))
+ continue; // Edges from blocks not dominated by |HeaderBlock| keep targeting |OldMerge|.
+
+ OldMerge.replacePhiUsesWith(Predecessor, NewMerge);
+ Predecessor->removeSuccessor(&OldMerge);
+ Predecessor->addSuccessorWithoutProb(NewMerge);
+ for (auto &I : *Predecessor) { // Retarget every instruction operand referencing |OldMerge|.
+ for (auto &O : I.operands()) {
+ if (O.isMBB() && O.getMBB() == &OldMerge)
+ O.setMBB(NewMerge);
+ }
+ }
+ }
+
+ auto *MI = getMergeInstruction(HeaderBlock);
+ assert(MI); // |HeaderBlock| is expected to contain a merge instruction.
+ MI->getOperand(0).setMBB(NewMerge);
+
+ return NewMerge;
+ }
+
+ // Modifies the CFG to make sure each merge block is the target of a single
+ // header.
+ bool splitMergeBlocks(MachineFunction &MF) {
+ MachineDominatorTree MDT(MF);
+
+ // Determine all the blocks we need to analyse.
+ auto HeaderBlocks = getHeaderBlocks(MF);
+ // Visit the CFG DFS-style to process header blocks.
+ std::vector<MachineBasicBlock *> ToProcess;
+ visit(MF, [&ToProcess, &HeaderBlocks](MachineBasicBlock *MBB) {
+ if (HeaderBlocks.count(MBB) != 0)
+ ToProcess.push_back(MBB);
+ });
+
+ // Maps each merge-block to its associated header block.
+ std::unordered_map<MachineBasicBlock *, MachineBasicBlock *> MergeToHeader;
+ bool modified = false;
----------------
michalpaszkowski wrote:
Please start [variable names with upper case letter](https://llvm.org/docs/CodingStandards.html#id44).
https://github.com/llvm/llvm-project/pull/97606
More information about the llvm-commits
mailing list