[llvm] [SPIR-V] Add SPIR-V structurizer (PR #97606)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Jul 3 09:45:16 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-spir-v
Author: Nathan Gauër (Keenuts)
<details>
<summary>Changes</summary>
This commit adds an initial SPIR-V structurizer.
It leverages the previously merged passes and the convergence region analysis to determine the correct merge and continue blocks for SPIR-V.
The first part replaces switches with a single edge with a simple branch.
Then, we add an OpLoopMerge instruction to each loop encountered. Next, we add an OpSelectionMerge instruction to each conditional branch that has no OpLoopMerge instruction.
Finally, we fixup the merge instructions:
- we split blocks with multiple headers into 2 blocks.
- we split blocks that are merge blocks for 2 or more constructs: the SPIR-V spec disallows a merge block from being shared by 2 loop/switch/condition constructs.
- we split merge & continue blocks: the SPIR-V spec disallows a basic block from being both a continue block and a merge block.
- we remove superfluous headers: when a header doesn't provide more information about the divergence state than its parent, it must be removed.
- we sort blocks according to the dominator tree order: the SPIR-V spec requires blocks to appear in a specific order.
As is, this code seems to pass most DXC structurization tests, except those relying on unrelated features this backend doesn't support yet, like i64 switches and boolean types.
One known case fails, but only because the MergeExit pass doesn't support switches yet (there is a FIXME).
This PR is already big enough as is, so I think we should get this in first, and then add support for those switches.
---
Patch is 77.04 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/97606.diff
19 Files Affected:
- (modified) llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp (+2-2)
- (modified) llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.h (+3)
- (modified) llvm/lib/Target/SPIRV/CMakeLists.txt (+1)
- (modified) llvm/lib/Target/SPIRV/SPIRV.h (+1)
- (modified) llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp (+12-8)
- (added) llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp (+649)
- (modified) llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp (+5-2)
- (removed) llvm/test/CodeGen/SPIRV/scfg-add-pre-headers.ll (-80)
- (added) llvm/test/CodeGen/SPIRV/structurizer/condition-linear.ll (+127)
- (added) llvm/test/CodeGen/SPIRV/structurizer/do-break.ll (+88)
- (added) llvm/test/CodeGen/SPIRV/structurizer/do-continue.ll (+124)
- (added) llvm/test/CodeGen/SPIRV/structurizer/do-nested.ll (+102)
- (added) llvm/test/CodeGen/SPIRV/structurizer/do-plain.ll (+98)
- (added) llvm/test/CodeGen/SPIRV/structurizer/logical-or.ll (+158)
- (modified) llvm/test/CodeGen/SPIRV/structurizer/merge-exit-break.ll (+21-18)
- (modified) llvm/test/CodeGen/SPIRV/structurizer/merge-exit-convergence-in-break.ll (+30-28)
- (modified) llvm/test/CodeGen/SPIRV/structurizer/merge-exit-multiple-break.ll (+36-37)
- (modified) llvm/test/CodeGen/SPIRV/structurizer/merge-exit-simple-white-identity.ll (+6-4)
- (added) llvm/test/CodeGen/SPIRV/structurizer/return-early.ll (+130)
``````````diff
diff --git a/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp b/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp
index 25e285e35f933..d77900001ea4b 100644
--- a/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp
@@ -230,7 +230,8 @@ class ConvergenceRegionAnalyzer {
auto *Terminator = From->getTerminator();
for (unsigned i = 0; i < Terminator->getNumSuccessors(); ++i) {
auto *To = Terminator->getSuccessor(i);
- if (isBackEdge(From, To))
+ // Ignore back edges and self edges.
+ if (From == To || isBackEdge(From, To))
continue;
auto ChildSet = findPathsToMatch(LI, To, isMatch);
@@ -276,7 +277,6 @@ class ConvergenceRegionAnalyzer {
while (ToProcess.size() != 0) {
auto *L = ToProcess.front();
ToProcess.pop();
- assert(L->isLoopSimplifyForm());
auto CT = getConvergenceToken(L->getHeader());
SmallPtrSet<BasicBlock *, 8> RegionBlocks(L->block_begin(),
diff --git a/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.h b/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.h
index f9e30e4effa1d..e435c88c919c9 100644
--- a/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.h
+++ b/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.h
@@ -130,6 +130,9 @@ class ConvergenceRegionInfo {
}
const ConvergenceRegion *getTopLevelRegion() const { return TopLevelRegion; }
+  // Mutable counterpart of getTopLevelRegion(): returns the root of the
+  // convergence region tree without const qualification, so a caller can
+  // update region membership in place (e.g. after inserting new blocks).
+  // NOTE(review): the method itself is const yet exposes mutable state.
+  ConvergenceRegion *getWritableTopLevelRegion() const {
+    return TopLevelRegion;
+  }
};
} // namespace SPIRV
diff --git a/llvm/lib/Target/SPIRV/CMakeLists.txt b/llvm/lib/Target/SPIRV/CMakeLists.txt
index 14647e92f5d08..425da84b62db1 100644
--- a/llvm/lib/Target/SPIRV/CMakeLists.txt
+++ b/llvm/lib/Target/SPIRV/CMakeLists.txt
@@ -31,6 +31,7 @@ add_llvm_target(SPIRVCodeGen
SPIRVMCInstLower.cpp
SPIRVMetadata.cpp
SPIRVModuleAnalysis.cpp
+ SPIRVStructurizer.cpp
SPIRVPreLegalizer.cpp
SPIRVPostLegalizer.cpp
SPIRVPrepareFunctions.cpp
diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h
index e597a1dc8dc06..d6597db8bc0e6 100644
--- a/llvm/lib/Target/SPIRV/SPIRV.h
+++ b/llvm/lib/Target/SPIRV/SPIRV.h
@@ -20,6 +20,7 @@ class InstructionSelector;
class RegisterBankInfo;
ModulePass *createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM);
+FunctionPass *createSPIRVStructurizerPass();
FunctionPass *createSPIRVMergeRegionExitTargetsPass();
FunctionPass *createSPIRVStripConvergenceIntrinsicsPass();
FunctionPass *createSPIRVRegularizerPass();
diff --git a/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp b/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp
index 0747dd1bbaf40..9930d067173df 100644
--- a/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp
@@ -133,7 +133,7 @@ class SPIRVMergeRegionExitTargets : public FunctionPass {
// Run the pass on the given convergence region, ignoring the sub-regions.
// Returns true if the CFG changed, false otherwise.
bool runOnConvergenceRegionNoRecurse(LoopInfo &LI,
- const SPIRV::ConvergenceRegion *CR) {
+ SPIRV::ConvergenceRegion *CR) {
// Gather all the exit targets for this region.
SmallPtrSet<BasicBlock *, 4> ExitTargets;
for (BasicBlock *Exit : CR->Exits) {
@@ -198,14 +198,19 @@ class SPIRVMergeRegionExitTargets : public FunctionPass {
for (auto Exit : CR->Exits)
replaceBranchTargets(Exit, ExitTargets, NewExitTarget);
+ CR = CR->Parent;
+ while (CR) {
+ CR->Blocks.insert(NewExitTarget);
+ CR = CR->Parent;
+ }
+
return true;
}
/// Run the pass on the given convergence region and sub-regions (DFS).
/// Returns true if a region/sub-region was modified, false otherwise.
/// This returns as soon as one region/sub-region has been modified.
- bool runOnConvergenceRegion(LoopInfo &LI,
- const SPIRV::ConvergenceRegion *CR) {
+ bool runOnConvergenceRegion(LoopInfo &LI, SPIRV::ConvergenceRegion *CR) {
for (auto *Child : CR->Children)
if (runOnConvergenceRegion(LI, Child))
return true;
@@ -235,10 +240,10 @@ class SPIRVMergeRegionExitTargets : public FunctionPass {
virtual bool runOnFunction(Function &F) override {
LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
- const auto *TopLevelRegion =
+ auto *TopLevelRegion =
getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
.getRegionInfo()
- .getTopLevelRegion();
+ .getWritableTopLevelRegion();
// FIXME: very inefficient method: each time a region is modified, we bubble
// back up, and recompute the whole convergence region tree. Once the
@@ -246,9 +251,6 @@ class SPIRVMergeRegionExitTargets : public FunctionPass {
// to be efficient instead of simple.
bool modified = false;
while (runOnConvergenceRegion(LI, TopLevelRegion)) {
- TopLevelRegion = getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
- .getRegionInfo()
- .getTopLevelRegion();
modified = true;
}
@@ -262,6 +264,8 @@ class SPIRVMergeRegionExitTargets : public FunctionPass {
AU.addRequired<DominatorTreeWrapperPass>();
AU.addRequired<LoopInfoWrapperPass>();
AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>();
+
+ AU.addPreserved<SPIRVConvergenceRegionAnalysisWrapperPass>();
FunctionPass::getAnalysisUsage(AU);
}
};
diff --git a/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
new file mode 100644
index 0000000000000..ac20f16905022
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
@@ -0,0 +1,649 @@
+//===-- SPIRVStructurizer.cpp -----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass adds the required OpSelection/OpLoop merge instructions to
+// generate valid SPIR-V.
+// This pass trims convergence intrinsics as those were only useful when
+// modifying the CFG during IR passes.
+//
+//===----------------------------------------------------------------------===//
+
+#include <stack>
+
+#include "Analysis/SPIRVConvergenceRegionAnalysis.h"
+#include "SPIRV.h"
+#include "SPIRVSubtarget.h"
+#include "SPIRVTargetMachine.h"
+#include "SPIRVUtils.h"
+#include "llvm/ADT/BreadthFirstIterator.h"
+#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
+#include "llvm/CodeGen/IntrinsicLowering.h"
+#include "llvm/CodeGen/MachineLoopInfo.h"
+#include "llvm/CodeGen/MachinePostDominators.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
+
+using namespace llvm;
+using namespace SPIRV;
+
+namespace llvm {
+void initializeSPIRVStructurizerPass(PassRegistry &);
+}
+
+namespace {
+
+// Searches the convergence region tree rooted at |Node| (depth first, parent
+// before children) for the region whose entry is the IR block associated with
+// |MBB|. Returns the first such region found, or nullptr if none matches.
+const ConvergenceRegion *getRegionForHeader(const ConvergenceRegion *Node,
+                                            MachineBasicBlock *MBB) {
+  const BasicBlock *Target = MBB->getBasicBlock();
+  if (Node->Entry == Target)
+    return Node;
+
+  for (const ConvergenceRegion *Child : Node->Children)
+    if (const ConvergenceRegion *Found = getRegionForHeader(Child, MBB))
+      return Found;
+  return nullptr;
+}
+
+// Returns the first MachineBasicBlock of |MF| whose associated IR block is
+// |BB|, or nullptr if no block of |MF| refers to |BB|. This is a linear scan
+// over the blocks of |MF|.
+MachineBasicBlock *getMachineBlockFor(MachineFunction &MF, BasicBlock *BB) {
+  for (MachineBasicBlock &Candidate : MF) {
+    if (Candidate.getBasicBlock() != BB)
+      continue;
+    return &Candidate;
+  }
+  return nullptr;
+}
+
+// Gathers all the successors of |BB|.
+// Branch and switch terminators contribute each of their (distinct) targets;
+// `ret` and `unreachable` terminators have no successors. Any other
+// terminator kind is not supported and triggers an assertion.
+std::unordered_set<BasicBlock *> gatherSuccessors(BasicBlock *BB) {
+  std::unordered_set<BasicBlock *> output;
+  auto *T = BB->getTerminator();
+
+  if (auto *BI = dyn_cast<BranchInst>(T)) {
+    output.insert(BI->getSuccessor(0));
+    if (BI->isConditional())
+      output.insert(BI->getSuccessor(1));
+    return output;
+  }
+
+  if (auto *SI = dyn_cast<SwitchInst>(T)) {
+    output.insert(SI->getDefaultDest());
+    for (auto &Case : SI->cases())
+      output.insert(Case.getCaseSuccessor());
+    return output;
+  }
+
+  // Terminators without successors. Use `isa` rather than binding an unused
+  // variable through `dyn_cast`.
+  if (isa<ReturnInst>(T) || isa<UnreachableInst>(T))
+    return output;
+
+  assert(false && "Unhandled terminator type.");
+  return output;
+}
+
+// Collects every block outside of |CR| that is the target of an edge leaving
+// one of the blocks in CR->Exits, and returns the MachineBasicBlock of |MF|
+// matching that single target. Returns nullptr when no edge leaves the region.
+// Asserts if more than one outside target exists. MF must be the function CR
+// belongs to.
+MachineBasicBlock *getExitFor(MachineFunction &MF,
+                              const ConvergenceRegion *CR) {
+  std::unordered_set<BasicBlock *> OutsideTargets;
+  for (BasicBlock *Exiting : CR->Exits)
+    for (BasicBlock *Succ : gatherSuccessors(Exiting))
+      if (CR->Blocks.count(Succ) == 0)
+        OutsideTargets.insert(Succ);
+
+  assert(OutsideTargets.size() <= 1);
+  if (OutsideTargets.empty())
+    return nullptr;
+
+  return getMachineBlockFor(MF, *OutsideTargets.begin());
+}
+
+// Returns true if |I| is an OpLoopMerge or OpSelectionMerge instruction.
+bool isMergeInstruction(const MachineInstr &I) {
+  switch (I.getOpcode()) {
+  case SPIRV::OpLoopMerge:
+  case SPIRV::OpSelectionMerge:
+    return true;
+  default:
+    return false;
+  }
+}
+
+// Returns the first instruction of |MBB| whose opcode is |Opcode|, nullptr if
+// |MBB| contains no such instruction.
+MachineInstr *getFirstInstructionWithOpcode(MachineBasicBlock &MBB,
+                                            unsigned Opcode) {
+  for (MachineInstr &Candidate : MBB)
+    if (Candidate.getOpcode() == Opcode)
+      return &Candidate;
+  return nullptr;
+}
+
+// Returns the first OpLoopMerge or OpSelectionMerge instruction found in
+// |MBB|, nullptr otherwise.
+MachineInstr *getMergeInstruction(MachineBasicBlock &MBB) {
+  auto It = MBB.begin();
+  const auto End = MBB.end();
+  while (It != End && !isMergeInstruction(*It))
+    ++It;
+  return It == End ? nullptr : &*It;
+}
+
+// Returns the first OpLoopMerge instruction found in |MBB|, nullptr otherwise.
+MachineInstr *getLoopMergeInstruction(MachineBasicBlock &MBB) {
+  return getFirstInstructionWithOpcode(MBB, SPIRV::OpLoopMerge);
+}
+
+// Returns the first OpSelectionMerge instruction found in |MBB|, nullptr
+// otherwise.
+MachineInstr *getSelectionMergeInstruction(MachineBasicBlock &MBB) {
+  return getFirstInstructionWithOpcode(MBB, SPIRV::OpSelectionMerge);
+}
+
+// Walks every block reachable from the entry block of |MF| exactly once,
+// invoking |op| on each. Blocks are processed from a LIFO worklist and are
+// marked as discovered when pushed, so no block is scheduled twice.
+void visit(MachineFunction &MF, std::function<void(MachineBasicBlock *)> op) {
+  MachineBasicBlock *Entry = &*MF.begin();
+  SmallPtrSet<MachineBasicBlock *, 8> Discovered;
+  std::stack<MachineBasicBlock *> Pending;
+  Pending.push(Entry);
+  Discovered.insert(Entry);
+
+  while (!Pending.empty()) {
+    MachineBasicBlock *Current = Pending.top();
+    Pending.pop();
+
+    op(Current);
+
+    // insert().second is true only the first time a block is seen.
+    for (MachineBasicBlock *Next : Current->successors())
+      if (Discovered.insert(Next).second)
+        Pending.push(Next);
+  }
+}
+
+// Returns every block of |MF| that contains an OpSelectionMerge or
+// OpLoopMerge instruction.
+SmallPtrSet<MachineBasicBlock *, 8> getHeaderBlocks(MachineFunction &MF) {
+  SmallPtrSet<MachineBasicBlock *, 8> Headers;
+  for (MachineBasicBlock &MBB : MF)
+    if (getMergeInstruction(MBB))
+      Headers.insert(&MBB);
+  return Headers;
+}
+
+// Returns every block of |MF| named as the merge target (operand 0) of the
+// first OpSelectionMerge/OpLoopMerge instruction of some block.
+SmallPtrSet<MachineBasicBlock *, 8> getMergeBlocks(MachineFunction &MF) {
+  SmallPtrSet<MachineBasicBlock *, 8> Merges;
+  for (MachineBasicBlock &MBB : MF) {
+    const MachineInstr *Merge = getMergeInstruction(MBB);
+    if (!Merge)
+      continue;
+    Merges.insert(Merge->getOperand(0).getMBB());
+  }
+  return Merges;
+}
+
+// Returns every block of |MF| named as the continue target (operand 1) of a
+// block whose first merge instruction is an OpLoopMerge.
+SmallPtrSet<MachineBasicBlock *, 8> getContinueBlocks(MachineFunction &MF) {
+  SmallPtrSet<MachineBasicBlock *, 8> Continues;
+  for (MachineBasicBlock &MBB : MF) {
+    const MachineInstr *Merge = getMergeInstruction(MBB);
+    if (!Merge || Merge->getOpcode() != SPIRV::OpLoopMerge)
+      continue;
+    Continues.insert(Merge->getOperand(1).getMBB());
+  }
+  return Continues;
+}
+
+// Returns the nearest block that post-dominates every block in |range|, or
+// nullptr if no such block exists. Despite the function name, the query is
+// made against the post-dominator tree |MPDT|. |range| must not be empty.
+MachineBasicBlock *findNearestCommonDominator(
+    const iterator_range<std::vector<MachineBasicBlock *>::iterator> &range,
+    MachinePostDominatorTree &MPDT) {
+  assert(!range.empty());
+  MachineBasicBlock *Dom = *range.begin();
+  for (auto It = std::next(range.begin()); It != range.end(); ++It) {
+    Dom = MPDT.findNearestCommonDominator(Dom, *It);
+    // No common post-dominator: adding more blocks cannot produce one, and
+    // the dominator-tree query asserts on null inputs, so stop here.
+    if (!Dom)
+      return nullptr;
+  }
+  return Dom;
+}
+
+// Looks up the first OpSelectionMerge/OpLoopMerge instruction of |MBB| and
+// reports it through the out-parameters:
+//  - |MI| receives the instruction (nullptr if there is none),
+//  - |Merge| receives its merge target (operand 0), or nullptr,
+//  - |Continue| receives its continue target (operand 1) when it is an
+//    OpLoopMerge, nullptr otherwise.
+// Returns true if a merge instruction was found.
+bool getMergeInstructionTargets(MachineBasicBlock *MBB, MachineInstr **MI,
+                                MachineBasicBlock **Merge,
+                                MachineBasicBlock **Continue) {
+  MachineInstr *Found = getMergeInstruction(*MBB);
+  const bool IsLoopMerge =
+      Found != nullptr && Found->getOpcode() == SPIRV::OpLoopMerge;
+
+  *MI = Found;
+  *Merge = Found ? Found->getOperand(0).getMBB() : nullptr;
+  *Continue = IsLoopMerge ? Found->getOperand(1).getMBB() : nullptr;
+  return Found != nullptr;
+}
+
+} // anonymous namespace
+
+class SPIRVStructurizer : public MachineFunctionPass {
+public:
+ static char ID;
+
+ SPIRVStructurizer() : MachineFunctionPass(ID) {
+ initializeSPIRVStructurizerPass(*PassRegistry::getPassRegistry());
+ };
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ MachineFunctionPass::getAnalysisUsage(AU);
+ AU.addRequired<MachineLoopInfo>();
+ AU.addRequired<DominatorTreeWrapperPass>();
+ AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>();
+ }
+
+ // Creates a MachineBasicBlock ending with an OpReturn just after
+ // |Predecessor|. This function does not add any branch to |Predecessor|, but
+ // adds the new block to its successors.
+ MachineBasicBlock *createReturnBlock(MachineFunction &MF,
+ MachineBasicBlock &Predecessor) {
+ MachineBasicBlock *MBB =
+ MF.CreateMachineBasicBlock(Predecessor.getBasicBlock());
+ MF.push_back(MBB);
+ MBB->moveAfter(&Predecessor);
+ // This code doesn't add a branch instruction to this new return block.
+ // The caller will have to handle that.
+ Predecessor.addSuccessorWithoutProb(MBB);
+
+ MachineIRBuilder MIRBuilder(MF);
+ MIRBuilder.setInsertPt(*MBB, MBB->end());
+ MIRBuilder.buildInstr(SPIRV::OpReturn);
+
+ return MBB;
+ }
+
+ // Replace switches with a single target with an unconditional branch.
+ bool replaceEmptySwitchWithBranch(MachineFunction &MF) {
+ bool modified = false;
+ for (MachineBasicBlock &MBB : MF) {
+ MachineInstr *I = &*MBB.rbegin();
+ GIntrinsic *II = dyn_cast<GIntrinsic>(I);
+ if (!II || II->getIntrinsicID() != Intrinsic::spv_switch ||
+ II->getNumOperands() > 3)
+ continue;
+
+ modified = true;
+ assert(II->getOperand(2).isMBB());
+ MachineBasicBlock *Target = II->getOperand(2).getMBB();
+
+ MachineIRBuilder MIRBuilder(MF);
+ MIRBuilder.setInsertPt(MBB, MBB.end());
+ MIRBuilder.buildBr(*Target);
+ MBB.erase(I);
+ }
+
+ return modified;
+ }
+
+ // Traverse each loop, and adds an OpLoopMerge instruction to its header
+ // that respect the convergence region node it belongs to.
+ // The Continue target is the only back-edge in that loop.
+ // The merge target is the only exiting node of the convergence region.
+ bool addMergeForLoops(MachineFunction &MF) {
+ auto &TII = *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo();
+ auto &TRI = *MF.getSubtarget<SPIRVSubtarget>().getRegisterInfo();
+ auto &RBI = *MF.getSubtarget<SPIRVSubtarget>().getRegBankInfo();
+
+ const auto &MLI = getAnalysis<MachineLoopInfo>();
+ const auto *TopLevelRegion =
+ getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
+ .getRegionInfo()
+ .getTopLevelRegion();
+
+ bool modified = false;
+ for (auto &MBB : MF) {
+ // Not a loop header. Ignoring for now.
+ if (!MLI.isLoopHeader(&MBB))
+ continue;
+ auto *L = MLI.getLoopFor(&MBB);
+
+ // This loop header is not the entrance of a convergence region. Ignoring
+ // this block.
+ auto *CR = getRegionForHeader(TopLevelRegion, &MBB);
+ if (CR == nullptr)
+ continue;
+
+ auto *Merge = getExitFor(MF, CR);
+ // This is a special case:
+ // We are indeed in a loop, but there are no exits (infinite loop).
+ // This means the actual branch is unconditional, hence won't require any
+ // OpLoopMerge.
+ if (Merge == nullptr) {
+ Merge = createReturnBlock(MF, MBB);
+ }
+
+ auto *Continue = L->getLoopLatch();
+
+ // Conditional branch are built using a fallthrough if false + BR.
+ // So the last instruction is not always the first branch.
+ auto *I = &*MBB.getFirstTerminator();
+ BuildMI(MBB, I, I->getDebugLoc(), TII.get(SPIRV::OpLoopMerge))
+ .addMBB(Merge)
+ .addMBB(Continue)
+ .addImm(SPIRV::SelectionControl::None)
+ .constrainAllUses(TII, TRI, RBI);
+ modified = true;
+ }
+
+ return modified;
+ }
+
+ // Add an OpSelectionMerge to each node with an out-degree of 2 or more.
+ bool addMergeForConditionalBranches(MachineFunction &MF) {
+ MachinePostDominatorTree MPDT(MF);
+
+ auto &TII = *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo();
+ auto &TRI = *MF.getSubtarget<SPIRVSubtarget>().getRegisterInfo();
+ auto &RBI = *MF.getSubtarget<SPIRVSubtarget>().getRegBankInfo();
+
+ auto MergeBlocks = getMergeBlocks(MF);
+ auto ContinueBlocks = getContinueBlocks(MF);
+
+ for (auto &MBB : MF) {
+ if (MBB.succ_size() <= 1)
+ continue;
+
+ // Block already has an OpSelectionMerge instruction. Ignoring.
+ if (getSelectionMergeInstruction(MBB)) {
+ continue;
+ }
+
+ assert(MBB.succ_size() >= 2);
+ size_t NonStructurizedTargets = 0;
+ for (MachineBasicBlock *Successor : MBB.successors()) {
+ if (!MergeBlocks.contains(Successor) &&
+ !ContinueBlocks.contains(Successor))
+ NonStructurizedTargets += 1;
+ }
+
+ if (NonStructurizedTargets <= 1)
+ continue;
+
+ MachineBasicBlock *Merge =
+ findNearestCommonDominator(MBB.successors(), MPDT);
+ if (!Merge) {
+ // TODO: we should check which block is not another construct merge
+ // block, and select this one. For now, tests passes with this strategy,
+ // but once we find a test case, we should fix that.
+ Merge = *MBB.succ_begin();
+ }
+
+ assert(Merge);
+ auto *II = MBB.getFirstTerminator() == MBB.end()
+ ? &*MBB.rbegin()
+ : &*MBB.getFirstTerminator();
+ BuildMI(MBB, II, II->getDebugLoc(), TII.get(SPIRV::OpSelectionMerge))
+ .addMBB(Merge)
+ .addImm(SPIRV::SelectionControl::None)
+ .constrainAllUses(TII, TRI, RBI);
+ }
+
+ return false;
+ }
+
+ // Cut |Block| just after the first OpLoopMerge/OpSelectionMerge instruction.
+ // The newly created block lies just after |Block|, and |Block| branches
+ // unconditionally to this new block. Returns the newly created block.
+ MachineBasicBlock *splitHeaderBlock(MachineFunction &MF,
+ MachineBasicBlock &Block) {
+ auto FirstMerge = Block.begin();
+ while (!isMergeInstruction(*FirstMerge)) {
+ FirstMerge++;
+ }
+
+ MachineBasicBlock *NewBlock = Block.splitAt(*FirstMerge);
+
+ MachineIRBuilder MIRBuilder(MF);
+ MIRBuilder.setInsertPt(Block, Block.end());
+ MIRBuilder.buildBr(*NewBlock);
+
+ retur...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/97606
More information about the llvm-commits
mailing list