[llvm] [SPIR-V] Add SPIR-V structurizer (PR #97606)
Steven Perron via llvm-commits
llvm-commits at lists.llvm.org
Mon Jul 15 07:29:00 PDT 2024
Nathan Gauër <brioche at google.com>,
Nathan Gauër <brioche at google.com>
Message-ID:
In-Reply-To: <llvm.org/llvm/llvm-project/pull/97606 at github.com>
================
@@ -0,0 +1,643 @@
+//===-- SPIRVStructurizer.cpp -----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass adds the required OpSelection/OpLoop merge instructions to
+// generate valid SPIR-V.
+// This pass trims convergence intrinsics as those were only useful when
+// modifying the CFG during IR passes.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Analysis/SPIRVConvergenceRegionAnalysis.h"
+#include "SPIRV.h"
+#include "SPIRVSubtarget.h"
+#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
+#include "llvm/CodeGen/MachineLoopInfo.h"
+#include "llvm/CodeGen/MachinePostDominators.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/InitializePasses.h"
+#include <stack>
+
+using namespace llvm;
+using namespace SPIRV;
+
+namespace llvm {
+// Forward declaration of the pass initializer so it can be referenced from
+// this file. Presumably its definition is generated by the pass-registration
+// macros later in the file (not visible in this chunk -- verify).
+void initializeSPIRVStructurizerPass(PassRegistry &);
+}
+
+namespace {
+
+// Searches the convergence-region tree rooted at `Node` (parent first, then
+// each child subtree in order) for the region whose entry is the IR block
+// underlying `MBB`. Returns the first such region, or nullptr if no region in
+// the tree has that block as its entry.
+const ConvergenceRegion *getRegionForHeader(const ConvergenceRegion *Node,
+                                            MachineBasicBlock *MBB) {
+  const BasicBlock *Header = MBB->getBasicBlock();
+  if (Node->Entry == Header)
+    return Node;
+
+  for (const auto *Child : Node->Children)
+    if (const ConvergenceRegion *Found = getRegionForHeader(Child, MBB))
+      return Found;
+
+  return nullptr;
+}
+
+// Scans the blocks of `MF` in layout order and returns the first
+// MachineBasicBlock whose underlying IR block is `BB`, or nullptr if no
+// machine block maps to it.
+MachineBasicBlock *getMachineBlockFor(MachineFunction &MF, BasicBlock *BB) {
+  for (MachineBasicBlock &Candidate : MF) {
+    if (Candidate.getBasicBlock() != BB)
+      continue;
+    return &Candidate;
+  }
+  return nullptr;
+}
+
+// Returns the set of distinct successor blocks of |BB|, as listed by its
+// terminator. All terminator kinds are handled uniformly through the generic
+// Instruction successor API: branches and switches contribute their targets
+// (duplicates are collapsed by the set), while terminators without successors
+// (e.g. `ret` or `unreachable`) yield an empty set.
+// Asserts if |BB| has no terminator.
+std::unordered_set<BasicBlock *> gatherSuccessors(BasicBlock *BB) {
+  std::unordered_set<BasicBlock *> Output;
+  const auto *T = BB->getTerminator();
+  assert(T != nullptr && "Block has no terminator.");
+
+  // For a branch this yields the true/false targets; for a switch, the
+  // default destination followed by every case destination.
+  for (unsigned I = 0, E = T->getNumSuccessors(); I != E; ++I)
+    Output.insert(T->getSuccessor(I));
+  return Output;
+}
+
+// Collects every block that is a successor of one of the exits of the
+// convergence region `CR` and lies outside of that region. At most one such
+// target is expected (asserted). Returns the MachineBasicBlock of `MF` that
+// corresponds to that target, or nullptr when no exit leaves the region.
+// MF must be the function CR belongs to.
+MachineBasicBlock *getExitFor(MachineFunction &MF,
+                              const ConvergenceRegion *CR) {
+  std::unordered_set<BasicBlock *> ExitTargets;
+  for (BasicBlock *Exiting : CR->Exits)
+    for (BasicBlock *Succ : gatherSuccessors(Exiting))
+      if (!CR->Blocks.count(Succ))
+        ExitTargets.insert(Succ);
+
+  assert(ExitTargets.size() <= 1);
+  if (ExitTargets.empty())
+    return nullptr;
+
+  return getMachineBlockFor(MF, *ExitTargets.begin());
+}
+
+// Returns true if |I| is an OpLoopMerge or an OpSelectionMerge instruction.
+bool isMergeInstruction(const MachineInstr &I) {
+  switch (I.getOpcode()) {
+  case SPIRV::OpLoopMerge:
+  case SPIRV::OpSelectionMerge:
+    return true;
+  default:
+    return false;
+  }
+}
+
+// Scans |MBB| from its first instruction and returns the first one that is
+// either an OpLoopMerge or an OpSelectionMerge, or nullptr if there is none.
+MachineInstr *getMergeInstruction(MachineBasicBlock &MBB) {
+  for (auto It = MBB.begin(), End = MBB.end(); It != End; ++It)
+    if (isMergeInstruction(*It))
+      return &*It;
+  return nullptr;
+}
+
+// Scans |MBB| from its first instruction and returns the first OpLoopMerge,
+// or nullptr when the block contains none.
+MachineInstr *getLoopMergeInstruction(MachineBasicBlock &MBB) {
+  for (auto It = MBB.begin(), End = MBB.end(); It != End; ++It) {
+    if (It->getOpcode() != SPIRV::OpLoopMerge)
+      continue;
+    return &*It;
+  }
+  return nullptr;
+}
+
+// Scans |MBB| from its first instruction and returns the first
+// OpSelectionMerge, or nullptr when the block contains none.
+MachineInstr *getSelectionMergeInstruction(MachineBasicBlock &MBB) {
+  for (auto It = MBB.begin(), End = MBB.end(); It != End; ++It) {
+    if (It->getOpcode() != SPIRV::OpSelectionMerge)
+      continue;
+    return &*It;
+  }
+  return nullptr;
+}
+
+// Walks the CFG of |MF| starting at its entry block, using an explicit stack
+// (depth-first-like order). |op| is called exactly once on every block
+// reachable from the entry; unreachable blocks are not visited. Does nothing
+// when |MF| has no blocks.
+void visit(MachineFunction &MF, std::function<void(MachineBasicBlock *)> op) {
+  // Dereferencing MF.begin() on a block-less function would be undefined.
+  if (MF.empty())
+    return;
+
+  std::stack<MachineBasicBlock *> ToVisit;
+  SmallPtrSet<MachineBasicBlock *, 8> Seen;
+
+  MachineBasicBlock *Entry = &MF.front();
+  ToVisit.push(Entry);
+  Seen.insert(Entry);
+  while (!ToVisit.empty()) {
+    MachineBasicBlock *MBB = ToVisit.top();
+    ToVisit.pop();
+
+    op(MBB);
+
+    // Blocks are marked seen when pushed, so each is pushed at most once;
+    // insert().second performs the membership test and insertion in one step.
+    for (MachineBasicBlock *Succ : MBB->successors())
+      if (Seen.insert(Succ).second)
+        ToVisit.push(Succ);
+  }
+}
+
+// Collects every block of |MF| that contains at least one OpSelectionMerge or
+// OpLoopMerge instruction.
+SmallPtrSet<MachineBasicBlock *, 8> getHeaderBlocks(MachineFunction &MF) {
+  SmallPtrSet<MachineBasicBlock *, 8> Headers;
+  for (MachineBasicBlock &Block : MF)
+    if (getMergeInstruction(Block))
+      Headers.insert(&Block);
+  return Headers;
+}
+
+// Collects the merge-block operand (operand 0) of the first
+// OpSelectionMerge/OpLoopMerge instruction found in each block of |MF|.
+SmallPtrSet<MachineBasicBlock *, 8> getMergeBlocks(MachineFunction &MF) {
+  SmallPtrSet<MachineBasicBlock *, 8> MergeTargets;
+  for (MachineBasicBlock &Block : MF) {
+    const MachineInstr *Merge = getMergeInstruction(Block);
+    if (!Merge)
+      continue;
+    MergeTargets.insert(Merge->getOperand(0).getMBB());
+  }
+  return MergeTargets;
+}
+
+// Collects the continue-target operand (operand 1) of every block of |MF|
+// whose first merge instruction is an OpLoopMerge.
+SmallPtrSet<MachineBasicBlock *, 8> getContinueBlocks(MachineFunction &MF) {
+  SmallPtrSet<MachineBasicBlock *, 8> ContinueTargets;
+  for (MachineBasicBlock &Block : MF) {
+    const MachineInstr *Merge = getMergeInstruction(Block);
+    if (!Merge || Merge->getOpcode() != SPIRV::OpLoopMerge)
+      continue;
+    ContinueTargets.insert(Merge->getOperand(1).getMBB());
+  }
+  return ContinueTargets;
+}
+
+// Returns the block immediately post-dominating every block in |Range| if any,
+// nullptr otherwise.
+MachineBasicBlock *findNearestCommonDominator(
----------------
s-perron wrote:
The comment says you are looking for a node that post-dominates the range, but the function name suggests it just looks for a dominator. This is confusing because they are different things.
```suggestion
MachineBasicBlock *findNearestCommonPostDominator(
```
https://github.com/llvm/llvm-project/pull/97606
More information about the llvm-commits
mailing list