[llvm] [SPIR-V] Add SPIR-V structurizer (PR #97606)

Steven Perron via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 15 07:28:59 PDT 2024


Nathan Gauër <brioche at google.com>,
Nathan Gauër <brioche at google.com>
Message-ID:
In-Reply-To: <llvm.org/llvm/llvm-project/pull/97606 at github.com>


================
@@ -0,0 +1,643 @@
+//===-- SPIRVStructurizer.cpp -----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass adds the required OpSelection/OpLoop merge instructions to
+// generate valid SPIR-V.
+// This pass trims convergence intrinsics as those were only useful when
+// modifying the CFG during IR passes.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Analysis/SPIRVConvergenceRegionAnalysis.h"
+#include "SPIRV.h"
+#include "SPIRVSubtarget.h"
+#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
+#include "llvm/CodeGen/MachineLoopInfo.h"
+#include "llvm/CodeGen/MachinePostDominators.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/InitializePasses.h"
+#include <stack>
+
+using namespace llvm;
+using namespace SPIRV;
+
+// Forward declaration of the pass initialization hook, which the
+// SPIRVStructurizer constructor calls with the global pass registry.
+// NOTE(review): the definition is presumably generated by the INITIALIZE_PASS
+// macros further down in this file — confirm.
+namespace llvm {
+void initializeSPIRVStructurizerPass(PassRegistry &);
+}
+
+namespace {
+
+// Searches the convergence region tree rooted at `Node`, in pre-order, for a
+// region whose entry block is the IR basic block associated with `MBB`.
+// Returns the first such region, or nullptr when there is none.
+const ConvergenceRegion *getRegionForHeader(const ConvergenceRegion *Node,
+                                            MachineBasicBlock *MBB) {
+  if (Node->Entry != MBB->getBasicBlock()) {
+    // Not this node: look into each subtree, stopping at the first hit.
+    for (auto *Child : Node->Children) {
+      if (const auto *Found = getRegionForHeader(Child, MBB))
+        return Found;
+    }
+    return nullptr;
+  }
+  return Node;
+}
+
+// Looks up the machine basic block of `MF` whose associated IR block is `BB`.
+// Returns the first match in iteration order, nullptr when there is none.
+MachineBasicBlock *getMachineBlockFor(MachineFunction &MF, BasicBlock *BB) {
+  for (auto It = MF.begin(), End = MF.end(); It != End; ++It) {
+    if (It->getBasicBlock() != BB)
+      continue;
+    return &*It;
+  }
+  return nullptr;
+}
+
+// Gathers all the successors of |BB|.
+// Only branch, switch and return terminators are supported: any other
+// terminator trips an assertion (and yields an empty set when assertions are
+// disabled). A return terminator has no successor, hence an empty set.
+std::unordered_set<BasicBlock *> gatherSuccessors(BasicBlock *BB) {
+  std::unordered_set<BasicBlock *> Output;
+  auto *T = BB->getTerminator();
+
+  if (auto *BI = dyn_cast<BranchInst>(T)) {
+    Output.insert(BI->getSuccessor(0));
+    // Only conditional branches carry a second target.
+    if (BI->isConditional())
+      Output.insert(BI->getSuccessor(1));
+    return Output;
+  }
+
+  if (auto *SI = dyn_cast<SwitchInst>(T)) {
+    Output.insert(SI->getDefaultDest());
+    for (auto &Case : SI->cases())
+      Output.insert(Case.getCaseSuccessor());
+    return Output;
+  }
+
+  // isa<> rather than dyn_cast<>: the cast result is not needed, and binding
+  // it to a name only produces an unused-variable warning.
+  if (isa<ReturnInst>(T))
+    return Output;
+
+  assert(false && "Unhandled terminator type.");
+  return Output;
+}
+
+// Computes the block control flow reaches when leaving the convergence region
+// `CR`: the successors of the exit blocks of `CR` which are not part of
+// `CR->Blocks`. Asserts that there is at most one such block.
+// Returns the matching machine block of `MF`, or nullptr when the region has
+// no outside target. MF must be the function CR belongs to.
+MachineBasicBlock *getExitFor(MachineFunction &MF,
+                              const ConvergenceRegion *CR) {
+  std::unordered_set<BasicBlock *> Outside;
+  for (BasicBlock *ExitBlock : CR->Exits)
+    for (BasicBlock *Succ : gatherSuccessors(ExitBlock))
+      if (!CR->Blocks.count(Succ))
+        Outside.insert(Succ);
+
+  assert(Outside.size() <= 1);
+  return Outside.empty() ? nullptr : getMachineBlockFor(MF, *Outside.begin());
+}
+
+// Returns true if |I| is an OpLoopMerge or an OpSelectionMerge instruction,
+// false for any other opcode.
+bool isMergeInstruction(const MachineInstr &I) {
+  switch (I.getOpcode()) {
+  case SPIRV::OpLoopMerge:
+  case SPIRV::OpSelectionMerge:
+    return true;
+  default:
+    return false;
+  }
+}
+
+// Scans |MBB| from its first instruction and returns the first
+// OpLoopMerge/OpSelectionMerge instruction met, nullptr if there is none.
+MachineInstr *getMergeInstruction(MachineBasicBlock &MBB) {
+  for (auto It = MBB.begin(), End = MBB.end(); It != End; ++It) {
+    if (!isMergeInstruction(*It))
+      continue;
+    return &*It;
+  }
+  return nullptr;
+}
+
+// Scans |MBB| from its first instruction and returns the first OpLoopMerge
+// instruction met, nullptr if there is none.
+MachineInstr *getLoopMergeInstruction(MachineBasicBlock &MBB) {
+  for (MachineInstr &Candidate : MBB) {
+    if (Candidate.getOpcode() != SPIRV::OpLoopMerge)
+      continue;
+    return &Candidate;
+  }
+  return nullptr;
+}
+
+// Scans |MBB| from its first instruction and returns the first
+// OpSelectionMerge instruction met, nullptr if there is none.
+MachineInstr *getSelectionMergeInstruction(MachineBasicBlock &MBB) {
+  for (MachineInstr &Candidate : MBB) {
+    if (Candidate.getOpcode() != SPIRV::OpSelectionMerge)
+      continue;
+    return &Candidate;
+  }
+  return nullptr;
+}
+
+// Traverses the blocks of |MF| reachable from its first block, using an
+// explicit stack (DFS-style), and calls |op| once on each of them.
+// A block is marked as seen when it is pushed, so it is never queued twice.
+// Does nothing when |MF| contains no block.
+void visit(MachineFunction &MF, std::function<void(MachineBasicBlock *)> op) {
+  // Nothing to traverse; dereferencing MF.begin() would be invalid.
+  if (MF.empty())
+    return;
+
+  std::stack<MachineBasicBlock *> ToVisit;
+  SmallPtrSet<MachineBasicBlock *, 8> Seen;
+
+  ToVisit.push(&*MF.begin());
+  Seen.insert(ToVisit.top());
+  while (!ToVisit.empty()) {
+    MachineBasicBlock *MBB = ToVisit.top();
+    ToVisit.pop();
+
+    op(MBB);
+
+    for (MachineBasicBlock *Succ : MBB->successors()) {
+      // insert() reports whether the block was new, which avoids a separate
+      // membership lookup.
+      if (Seen.insert(Succ).second)
+        ToVisit.push(Succ);
+    }
+  }
+}
+
+// Collects every basic block of |MF| that contains at least one
+// OpSelectionMerge/OpLoopMerge instruction.
+SmallPtrSet<MachineBasicBlock *, 8> getHeaderBlocks(MachineFunction &MF) {
+  SmallPtrSet<MachineBasicBlock *, 8> Headers;
+  for (MachineBasicBlock &MBB : MF) {
+    if (getMergeInstruction(MBB) == nullptr)
+      continue;
+    Headers.insert(&MBB);
+  }
+  return Headers;
+}
+
+// Collects every basic block referenced as merge target (operand 0) by the
+// first OpSelectionMerge/OpLoopMerge instruction of a block of |MF|.
+SmallPtrSet<MachineBasicBlock *, 8> getMergeBlocks(MachineFunction &MF) {
+  SmallPtrSet<MachineBasicBlock *, 8> Targets;
+  for (MachineBasicBlock &MBB : MF) {
+    MachineInstr *MergeMI = getMergeInstruction(MBB);
+    if (!MergeMI)
+      continue;
+    Targets.insert(MergeMI->getOperand(0).getMBB());
+  }
+  return Targets;
+}
+
+// Collects every basic block referenced as continue target (operand 1) by a
+// block of |MF| whose first merge instruction is an OpLoopMerge.
+SmallPtrSet<MachineBasicBlock *, 8> getContinueBlocks(MachineFunction &MF) {
+  SmallPtrSet<MachineBasicBlock *, 8> Targets;
+  for (MachineBasicBlock &MBB : MF) {
+    MachineInstr *MergeMI = getMergeInstruction(MBB);
+    // Blocks without merge instruction, or starting with a selection merge,
+    // define no continue target.
+    if (!MergeMI || MergeMI->getOpcode() != SPIRV::OpLoopMerge)
+      continue;
+    Targets.insert(MergeMI->getOperand(1).getMBB());
+  }
+  return Targets;
+}
+
+// Folds MPDT.findNearestCommonDominator over every block of |Range|, starting
+// from its first element, and returns the result: the nearest block of the
+// post-dominator tree common to all of them. |Range| must not be empty.
+// NOTE(review): the result is presumably nullptr when no such block exists —
+// confirm against MachinePostDominatorTree.
+MachineBasicBlock *findNearestCommonDominator(
+    const iterator_range<std::vector<MachineBasicBlock *>::iterator> &Range,
+    MachinePostDominatorTree &MPDT) {
+  assert(!Range.empty());
+  auto It = Range.begin();
+  MachineBasicBlock *Common = *It;
+  for (auto End = Range.end(); It != End; ++It)
+    Common = MPDT.findNearestCommonDominator(Common, *It);
+  return Common;
+}
+
+// Looks for the first merge instruction of |MBB| and stores it in |MI|
+// (nullptr if there is none).
+// |Merge| receives its merge target (operand 0), nullptr without instruction.
+// |Continue| receives its continue target (operand 1) when the instruction is
+// an OpLoopMerge, nullptr otherwise.
+// Returns true if a merge instruction was found, false otherwise.
+bool getMergeInstructionTargets(MachineBasicBlock *MBB, MachineInstr **MI,
+                                MachineBasicBlock **Merge,
+                                MachineBasicBlock **Continue) {
+  // Reset the outputs first so they are well-defined on the failure path.
+  *Merge = nullptr;
+  *Continue = nullptr;
+
+  MachineInstr *Found = getMergeInstruction(*MBB);
+  *MI = Found;
+  if (!Found)
+    return false;
+
+  *Merge = Found->getOperand(0).getMBB();
+  if (Found->getOpcode() == SPIRV::OpLoopMerge)
+    *Continue = Found->getOperand(1).getMBB();
+  else
+    *Continue = nullptr;
+  return true;
+}
+
+} // anonymous namespace
+
+class SPIRVStructurizer : public MachineFunctionPass {
+public:
+  static char ID;
+
+  // Builds the pass and calls its initialization hook with the global pass
+  // registry.
+  SPIRVStructurizer() : MachineFunctionPass(ID) {
+    initializeSPIRVStructurizerPass(*PassRegistry::getPassRegistry());
+  }
+
+  // Declares the analyses this pass requires, on top of what the base
+  // MachineFunctionPass declares: machine loop info, a dominator tree and the
+  // SPIR-V convergence region analysis.
+  // NOTE(review): this requires DominatorTreeWrapperPass (not a Machine*
+  // variant) although the file includes MachinePostDominators.h — confirm
+  // this is the intended analysis.
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    MachineFunctionPass::getAnalysisUsage(AU);
+    AU.addRequired<MachineLoopInfoWrapperPass>();
+    AU.addRequired<DominatorTreeWrapperPass>();
+    AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>();
+  }
+
+  // Builds a new machine basic block holding a single OpReturn, placed right
+  // after |Predecessor| and associated with the same IR basic block.
+  // The new block is registered as a successor of |Predecessor|, but no branch
+  // instruction is emitted: adding the actual terminator to |Predecessor| is
+  // left to the caller.
+  MachineBasicBlock *createReturnBlock(MachineFunction &MF,
+                                       MachineBasicBlock &Predecessor) {
+    MachineBasicBlock *ReturnBlock =
+        MF.CreateMachineBasicBlock(Predecessor.getBasicBlock());
+    // Append to the function first, then relocate next to |Predecessor|.
+    MF.push_back(ReturnBlock);
+    ReturnBlock->moveAfter(&Predecessor);
+    Predecessor.addSuccessorWithoutProb(ReturnBlock);
+
+    MachineIRBuilder Builder(MF);
+    Builder.setInsertPt(*ReturnBlock, ReturnBlock->end());
+    Builder.buildInstr(SPIRV::OpReturn);
+    return ReturnBlock;
+  }
+
+  // Replaces switches with a single target with an unconditional branch.
+  // A block qualifies when its last instruction is an spv_switch intrinsic
+  // with at most 3 operands; operand 2 is then expected to be the target
+  // block. Empty blocks are skipped.
+  // Returns true if at least one switch was replaced.
+  bool replaceEmptySwitchWithBranch(MachineFunction &MF) {
+    bool Modified = false;
+    for (MachineBasicBlock &MBB : MF) {
+      // An empty block has no last instruction: dereferencing rbegin() would
+      // be invalid.
+      if (MBB.empty())
+        continue;
+
+      MachineInstr *I = &*MBB.rbegin();
+      GIntrinsic *II = dyn_cast<GIntrinsic>(I);
+      if (!II || II->getIntrinsicID() != Intrinsic::spv_switch ||
+          II->getNumOperands() > 3)
+        continue;
+
+      Modified = true;
+      assert(II->getOperand(2).isMBB());
+      MachineBasicBlock *Target = II->getOperand(2).getMBB();
+
+      // Emit the branch at the end of the block, then drop the switch.
+      MachineIRBuilder MIRBuilder(MF);
+      MIRBuilder.setInsertPt(MBB, MBB.end());
+      MIRBuilder.buildBr(*Target);
+      MBB.erase(I);
+    }
+
+    return Modified;
+  }
+
+  // Traverses each loop header which is also the entry of a convergence
+  // region, and adds an OpLoopMerge instruction before its first terminator.
+  // The continue target is the loop latch, which must be unique.
+  // The merge target is the only block reached when exiting the convergence
+  // region, or a newly created return block when the region has no exit.
+  // Returns true if at least one OpLoopMerge was added.
+  bool addMergeForLoops(MachineFunction &MF) {
+    auto &TII = *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo();
+    auto &TRI = *MF.getSubtarget<SPIRVSubtarget>().getRegisterInfo();
+    auto &RBI = *MF.getSubtarget<SPIRVSubtarget>().getRegBankInfo();
+
+    const auto &MLI = getAnalysis<MachineLoopInfoWrapperPass>().getLI();
+    const auto *TopLevelRegion =
+        getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
+            .getRegionInfo()
+            .getTopLevelRegion();
+
+    bool Modified = false;
+    for (auto &MBB : MF) {
+      // Not a loop header. Ignoring for now.
+      if (!MLI.isLoopHeader(&MBB))
+        continue;
+      auto *L = MLI.getLoopFor(&MBB);
+
+      // This loop header is not the entrance of a convergence region. Ignoring
+      // this block.
+      auto *CR = getRegionForHeader(TopLevelRegion, &MBB);
+      if (CR == nullptr)
+        continue;
+
+      auto *Merge = getExitFor(MF, CR);
+      // This is a special case:
+      // We are indeed in a loop, but there are no exits (infinite loop).
+      // The OpLoopMerge emitted below still needs a merge target, so a
+      // dedicated return block is created to play that role.
+      if (Merge == nullptr)
+        Merge = createReturnBlock(MF, MBB);
+
+      // getLoopLatch() returns nullptr when the loop does not have a single
+      // latch; such a loop has no unique continue target to reference.
+      auto *Continue = L->getLoopLatch();
+      assert(Continue != nullptr &&
+             "Cannot add an OpLoopMerge to a loop without a single latch.");
+
+      // Conditional branch are built using a fallthrough if false + BR.
+      // So the last instruction is not always the first branch.
+      auto FirstTerminator = MBB.getFirstTerminator();
+      assert(FirstTerminator != MBB.end() &&
+             "Loop header without terminator.");
+      auto *I = &*FirstTerminator;
+      // NOTE(review): the control operand is spelled SelectionControl::None
+      // although this is a loop merge; presumably equal to the loop-control
+      // 'None' value — confirm.
+      BuildMI(MBB, I, I->getDebugLoc(), TII.get(SPIRV::OpLoopMerge))
+          .addMBB(Merge)
+          .addMBB(Continue)
+          .addImm(SPIRV::SelectionControl::None)
+          .constrainAllUses(TII, TRI, RBI);
+      Modified = true;
+    }
+
+    return Modified;
+  }
+
+  // Add an OpSelectionMerge to each node with an out-degree of 2 or more.
+  bool addMergeForConditionalBranches(MachineFunction &MF) {
----------------
s-perron wrote:

How do you make sure that every block with an in-degree of 2 or more is the merge block of a header? The example you give below is a good illustration:

```
//          A
//         / \
//        B<--C
//        \    \
//         \    D
//          \  /
//           E
```

From what I can tell, this will add two `OpSelectionMerge` instructions. One in `A` and one in `C`. Both will use `E` as the merge block. `B` will not be a merge block. This violates the rules for maximal reconvergence (https://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/KHR/SPV_KHR_maximal_reconvergence.html).

Note that we want to use maximal reconvergence now that it is released. Otherwise, many cases would have implementation-defined behaviour.

https://github.com/llvm/llvm-project/pull/97606


More information about the llvm-commits mailing list