[clang] [llvm] [SPIR-V] Add SPIR-V structurizer (PR #107408)

Nathan Gauër via cfe-commits cfe-commits at lists.llvm.org
Mon Sep 9 02:18:41 PDT 2024


================
@@ -0,0 +1,1410 @@
+//===-- SPIRVStructurizer.cpp ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+//===----------------------------------------------------------------------===//
+
+#include "Analysis/SPIRVConvergenceRegionAnalysis.h"
+#include "SPIRV.h"
+#include "SPIRVSubtarget.h"
+#include "SPIRVTargetMachine.h"
+#include "SPIRVUtils.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/CodeGen/IntrinsicLowering.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Utils/LoopSimplify.h"
+#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
+#include <queue>
+#include <stack>
+
+using namespace llvm;
+using namespace SPIRV;
+
+namespace llvm {
+
+void initializeSPIRVStructurizerPass(PassRegistry &);
+
+namespace {
+
+// Unordered set of basic blocks.
+using BlockSet = std::unordered_set<BasicBlock *>;
+// Pair of basic blocks.
+// NOTE(review): presumably the (source, destination) of a CFG edge - confirm
+// against the use sites.
+using Edge = std::pair<BasicBlock *, BasicBlock *>;
+
+// This class implements a partial ordering visitor, which visits a cyclic graph
+// in natural topological-like ordering. Topological ordering is not defined for
+// directed graphs with cycles, so this assumes cycles are a single node, and
+// ignores back-edges. The cycle is visited from the entry in the same
+// topological-like ordering.
+//
+// This means once we visit a node, we know all the possible ancestors have been
+// visited.
+//
+// clang-format off
+//
+// Given this graph:
+//
+//     ,-> B -\
+// A -+        +---> D ----> E -> F -> G -> H
+//     `-> C -/      ^                 |
+//                   +-----------------+
+//
+// Visit order is:
+//  A, [B, C in any order], D, E, F, G, H
+//
+// clang-format on
+//
+// Changing the function CFG between the construction of the visitor and
+// visiting is undefined. The visitor can be reused, but if the CFG is updated,
+// the visitor must be rebuilt.
+class PartialOrderingVisitor {
+  // Dominator tree of the visited function, used to recognize back-edges.
+  DomTreeBuilder::BBDomTree DT;
+  // Loop information derived from DT.
+  LoopInfo LI;
+  // Blocks for which visit() has moved on to the successors.
+  BlockSet Visited;
+  // While ranks are being computed: block -> rank. Once the constructor is
+  // done: block -> index of the block in Order.
+  std::unordered_map<BasicBlock *, size_t> B2R;
+  // (block, rank) pairs, sorted by increasing rank.
+  std::vector<std::pair<BasicBlock *, size_t>> Order;
+
+  // Get all basic-blocks reachable from Start.
+  // Edges whose target dominates their source (back-edges) are not followed.
+  BlockSet getReachableFrom(BasicBlock *Start) {
+    std::queue<BasicBlock *> ToVisit;
+    ToVisit.push(Start);
+
+    BlockSet Output;
+    while (ToVisit.size() != 0) {
+      BasicBlock *BB = ToVisit.front();
+      ToVisit.pop();
+
+      // A block can be queued several times; only process it once.
+      if (Output.count(BB) != 0)
+        continue;
+      Output.insert(BB);
+
+      for (BasicBlock *Successor : successors(BB)) {
+        // Ignoring back-edges.
+        if (DT.dominates(Successor, BB))
+          continue;
+        ToVisit.push(Successor);
+      }
+    }
+
+    return Output;
+  }
+
+  // Records |Rank| for |BB| (keeping the largest rank seen so far) and, once
+  // every relevant predecessor of |BB| has a rank, recursively ranks its
+  // successors. Returns the highest rank handed out along the way, or |Rank|
+  // when |BB| was already completed or is still waiting for a predecessor.
+  size_t visit(BasicBlock *BB, size_t Rank) {
+    if (Visited.count(BB) != 0)
+      return Rank;
+
+    Loop *L = LI.getLoopFor(BB);
+    const bool isLoopHeader = LI.isLoopHeader(BB);
+
+    // A block reached through several paths keeps the largest rank.
+    if (B2R.count(BB) == 0) {
+      B2R.emplace(BB, Rank);
+    } else {
+      B2R[BB] = std::max(B2R[BB], Rank);
+    }
+
+    for (BasicBlock *Predecessor : predecessors(BB)) {
+      // For a loop header, predecessors inside the loop are back-edges and are
+      // not waited for.
+      if (isLoopHeader && L->contains(Predecessor)) {
+        continue;
+      }
+
+      // A predecessor has no rank yet: stop here, BB is reached again once
+      // that predecessor gets visited.
+      if (B2R.count(Predecessor) == 0) {
+        return Rank;
+      }
+    }
+
+    Visited.insert(BB);
+
+    SmallVector<BasicBlock *, 2> OtherSuccessors;
+    BasicBlock *LoopSuccessor = nullptr;
+
+    for (BasicBlock *Successor : successors(BB)) {
+      // Ignoring back-edges.
+      if (DT.dominates(Successor, BB))
+        continue;
+
+      // A loop header is expected to have at most one successor inside its
+      // loop; every other successor leaves the loop.
+      if (isLoopHeader && L->contains(Successor)) {
+        assert(LoopSuccessor == nullptr);
+        LoopSuccessor = Successor;
+      } else
+        OtherSuccessors.push_back(Successor);
+    }
+
+    // The loop body is ranked first; Rank then becomes the highest rank
+    // reached inside it, so the remaining successors are ranked after it.
+    if (LoopSuccessor)
+      Rank = visit(LoopSuccessor, Rank + 1);
+
+    size_t OutputRank = Rank;
+    for (BasicBlock *Item : OtherSuccessors)
+      OutputRank = std::max(OutputRank, visit(Item, Rank + 1));
+    return OutputRank;
+  };
+
+public:
+  // Build the visitor to operate on the function F.
+  // Computes the dominator tree and loop info of F, ranks every block reached
+  // from the first block of F, then builds Order sorted by rank.
+  PartialOrderingVisitor(Function &F) {
+    DT.recalculate(F);
+    LI = LoopInfo(DT);
+
+    visit(&*F.begin(), 0);
+
+    for (auto &[BB, Rank] : B2R)
+      Order.emplace_back(BB, Rank);
+
+    // NOTE(review): B2R is an unordered_map keyed by pointer and the sort has
+    // no tie-break, so the relative order of blocks sharing the same rank is
+    // unspecified.
+    std::sort(Order.begin(), Order.end(), [](const auto &LHS, const auto &RHS) {
+      return LHS.second < RHS.second;
+    });
+
+    // From now on, B2R maps a block to its index in Order, not to its rank.
+    for (size_t i = 0; i < Order.size(); i++)
+      B2R[Order[i].first] = i;
+  }
+
+  // Visit the function starting from the basic block |Start|, and calling |Op|
+  // on each visited BB. This traversal ignores back-edges, meaning this won't
+  // visit a node to which |Start| is not an ancestor.
+  // Blocks are visited by increasing rank. When |Op| returns false, the
+  // traversal stops once the remaining blocks of the current rank have been
+  // handled.
+  void partialOrderVisit(BasicBlock &Start,
+                         std::function<bool(BasicBlock *)> Op) {
+    BlockSet Reachable = getReachableFrom(&Start);
+    assert(B2R.count(&Start) != 0);
+    // B2R holds an index into Order here; the rank is stored in Order.
+    size_t Rank = Order[B2R[&Start]].second;
+
+    // Skip every block ranked before |Start|.
+    auto It = Order.begin();
+    while (It != Order.end() && It->second < Rank)
+      ++It;
+
+    if (It == Order.end())
+      return;
+
+    // Initially larger than any rank, so that nothing is cut off.
+    size_t EndRank = Order.rbegin()->second + 1;
+    for (; It != Order.end() && It->second <= EndRank; ++It) {
+      if (Reachable.count(It->first) == 0) {
+        continue;
+      }
+
+      // Op asked to stop: do not go past the current rank.
+      if (!Op(It->first)) {
+        EndRank = It->second;
+      }
+    }
+  }
+};
+
+// Convenience wrapper: builds a PartialOrderingVisitor for the function owning
+// |Start|, then runs a partial order visit from |Start|, calling |Op| on each
+// visited block.
+void partialOrderVisit(BasicBlock &Start,
+                       std::function<bool(BasicBlock *)> Op) {
+  Function &F = *Start.getParent();
+  PartialOrderingVisitor Visitor(F);
+  Visitor.partialOrderVisit(Start, Op);
+}
+
+// Searches the convergence region tree rooted at `Node`, depth-first, for the
+// region whose entry block is `BB`. Returns that region, or nullptr when no
+// region of the tree has `BB` as entry.
+const ConvergenceRegion *getRegionForHeader(const ConvergenceRegion *Node,
+                                            BasicBlock *BB) {
+  if (Node->Entry != BB) {
+    for (auto *Child : Node->Children)
+      if (const auto *Match = getRegionForHeader(Child, BB))
+        return Match;
+    return nullptr;
+  }
+  return Node;
+}
+
+// Returns the unique block outside of the convergence region `CR` which is a
+// successor of one of the region's exit blocks, or nullptr if there is no such
+// block. Asserts that the region has at most one such target.
+BasicBlock *getExitFor(const ConvergenceRegion *CR) {
+  std::unordered_set<BasicBlock *> Targets;
+  for (BasicBlock *ExitBlock : CR->Exits)
+    for (BasicBlock *Succ : successors(ExitBlock))
+      if (!CR->Blocks.count(Succ))
+        Targets.insert(Succ);
+
+  assert(Targets.size() <= 1);
+  return Targets.empty() ? nullptr : *Targets.begin();
+}
+
+// If I is a merge instruction (spv_loop_merge or spv_selection_merge
+// intrinsic), returns the block referenced by its first operand, which is the
+// merge block. Returns nullptr for any other instruction.
+BasicBlock *getDesignatedMergeBlock(Instruction *I) {
+  auto *II = dyn_cast<IntrinsicInst>(I);
+  if (!II)
+    return nullptr;
+
+  switch (II->getIntrinsicID()) {
+  case Intrinsic::spv_loop_merge:
+  case Intrinsic::spv_selection_merge:
+    return cast<BlockAddress>(II->getOperand(0))->getBasicBlock();
+  default:
+    return nullptr;
+  }
+}
+
+// If I is an OpLoopMerge (spv_loop_merge intrinsic), returns the block
+// referenced by its second operand, which is the continue block. Returns
+// nullptr for any other instruction.
+BasicBlock *getDesignatedContinueBlock(Instruction *I) {
+  if (auto *II = dyn_cast<IntrinsicInst>(I))
+    if (II->getIntrinsicID() == Intrinsic::spv_loop_merge)
+      return cast<BlockAddress>(II->getOperand(1))->getBasicBlock();
+  return nullptr;
+}
+
+// Returns true if Header contains a merge instruction designating Merge as
+// its merge block. Despite the name, both OpSelectionMerge and OpLoopMerge
+// instructions are considered.
+bool isDefinedAsSelectionMergeBy(BasicBlock &Header, BasicBlock &Merge) {
+  for (Instruction &Inst : Header)
+    if (getDesignatedMergeBlock(&Inst) == &Merge)
+      return true;
+  return false;
+}
+
+// Returns true if BB contains at least one OpLoopMerge instruction, that is an
+// instruction designating a continue block.
+bool hasLoopMergeInstruction(BasicBlock &BB) {
+  for (Instruction &Inst : BB) {
+    if (getDesignatedContinueBlock(&Inst) != nullptr)
+      return true;
+  }
+  return false;
+}
+
+// Returns true if I is an OpSelectionMerge or OpLoopMerge instruction, false
+// otherwise.
+bool isMergeInstruction(Instruction *I) {
+  BasicBlock *Merge = getDesignatedMergeBlock(I);
+  return Merge != nullptr;
+}
+
+// Collects every block of F containing at least one OpLoopMerge or
+// OpSelectionMerge instruction.
+SmallPtrSet<BasicBlock *, 2> getHeaderBlocks(Function &F) {
+  SmallPtrSet<BasicBlock *, 2> Headers;
+  for (BasicBlock &Block : F)
+    for (Instruction &Inst : Block)
+      if (getDesignatedMergeBlock(&Inst))
+        Headers.insert(&Block);
+  return Headers;
+}
+
+// Collects every block designated as merge block by at least one
+// OpSelectionMerge/OpLoopMerge instruction of |F|.
+SmallPtrSet<BasicBlock *, 2> getMergeBlocks(Function &F) {
+  SmallPtrSet<BasicBlock *, 2> Merges;
+  for (BasicBlock &Block : F)
+    for (Instruction &Inst : Block)
+      if (BasicBlock *Merge = getDesignatedMergeBlock(&Inst))
+        Merges.insert(Merge);
+  return Merges;
+}
+
+// Returns the merge instructions of BB, in the order they appear in the block.
+// Note: the SPIR-V spec allows at most one merge instruction per block, but
+// more can temporarily coexist while the CFG is being structurized.
+std::vector<Instruction *> getMergeInstructions(BasicBlock &BB) {
+  std::vector<Instruction *> Merges;
+  for (Instruction &Inst : BB) {
+    if (!isMergeInstruction(&Inst))
+      continue;
+    Merges.push_back(&Inst);
+  }
+  return Merges;
+}
+
+// Collects every block designated as continue target by at least one
+// OpLoopMerge instruction of |F|.
+SmallPtrSet<BasicBlock *, 2> getContinueBlocks(Function &F) {
+  SmallPtrSet<BasicBlock *, 2> Continues;
+  for (BasicBlock &Block : F)
+    for (Instruction &Inst : Block)
+      if (BasicBlock *Continue = getDesignatedContinueBlock(&Inst))
+        Continues.insert(Continue);
+  return Continues;
+}
+
+// Depth-first walk of the CFG starting from the BB |Start|, driven by an
+// explicit stack. |op| is called on each block taken from the stack; when it
+// returns false, the successors of that block are not queued. A block is
+// queued at most once.
+void visit(BasicBlock &Start, std::function<bool(BasicBlock *)> op) {
+  SmallPtrSet<BasicBlock *, 8> Seen;
+  std::stack<BasicBlock *> Worklist;
+
+  Seen.insert(&Start);
+  Worklist.push(&Start);
+  while (!Worklist.empty()) {
+    BasicBlock *Current = Worklist.top();
+    Worklist.pop();
+
+    if (!op(Current))
+      continue;
+
+    for (BasicBlock *Succ : successors(Current)) {
+      // insert() tells whether Succ was new, so it is both the membership
+      // test and the marking.
+      if (Seen.insert(Succ).second)
+        Worklist.push(Succ);
+    }
+  }
+}
+
+// Replaces the conditional and unconditional branch targets of |BB| by
+// |NewTarget| if the target was |OldTarget|. This function also makes sure the
+// associated merge instruction gets updated accordingly.
+//
+// The terminator of |BB| must be a BranchInst. If a conditional branch ends up
+// with two identical successors, it is replaced by an unconditional branch,
+// and an OpSelectionMerge placed right before it is erased.
+void replaceIfBranchTargets(BasicBlock *BB, BasicBlock *OldTarget,
+                            BasicBlock *NewTarget) {
+  auto *BI = cast<BranchInst>(BB->getTerminator());
+
+  // 1. Replace all matching successors.
+  for (size_t i = 0; i < BI->getNumSuccessors(); i++) {
+    if (BI->getSuccessor(i) == OldTarget)
+      BI->setSuccessor(i, NewTarget);
+  }
+
+  // Branch was unconditional, no fixup required.
+  if (BI->isUnconditional())
+    return;
+
+  // Branch had 2 successors, maybe now both are the same?
+  if (BI->getSuccessor(0) != BI->getSuccessor(1))
+    return;
+
+  // 2. Both successors are identical: turn the conditional branch into an
+  // unconditional one.
+  // Note: we may end up here because the original IR had such branches.
+  // This means the common successor is not necessarily equal to NewTarget.
+  IRBuilder<> Builder(BB);
+  Builder.SetInsertPoint(BI);
+  Builder.CreateBr(BI->getSuccessor(0));
+  BI->eraseFromParent();
+
+  // The branch was the only instruction, nothing else to do.
+  if (BB->size() == 1)
+    return;
+
+  // Otherwise, we need to check: was there an OpSelectionMerge before this
+  // branch? If we removed the OpBranchConditional, we must also remove the
+  // OpSelectionMerge. This does not apply to OpLoopMerge, which is left in
+  // place.
+  IntrinsicInst *II =
+      dyn_cast<IntrinsicInst>(BB->getTerminator()->getPrevNode());
+  if (!II || II->getIntrinsicID() != Intrinsic::spv_selection_merge)
+    return;
+
+  // Erase the merge instruction, then destroy its block operand constant if
+  // nothing else uses it.
+  Constant *C = cast<Constant>(II->getOperand(0));
+  II->eraseFromParent();
+  if (!C->isConstantUsed())
+    C->destroyConstant();
+}
+
+// Replaces the branching instruction destination of |BB| by |NewTarget| if it
+// was |OldTarget|. This function also fixes the associated merge instruction.
+// Note: this function does not simplify branching instructions, it only updates
+// targets. See also: simplifyBranches.
+//
+// Handled terminators:
+//  - ReturnInst: no successor, nothing to do.
+//  - BranchInst: delegated to replaceIfBranchTargets.
+//  - SwitchInst: every successor equal to |OldTarget| is redirected.
+// Any other terminator is a programming error.
+void replaceBranchTargets(BasicBlock *BB, BasicBlock *OldTarget,
+                          BasicBlock *NewTarget) {
+  auto *T = BB->getTerminator();
+  if (isa<ReturnInst>(T))
+    return;
+
+  if (isa<BranchInst>(T))
+    return replaceIfBranchTargets(BB, OldTarget, NewTarget);
+
+  if (auto *SI = dyn_cast<SwitchInst>(T)) {
+    for (size_t i = 0; i < SI->getNumSuccessors(); i++) {
+      if (SI->getSuccessor(i) == OldTarget)
+        SI->setSuccessor(i, NewTarget);
+    }
+    return;
+  }
+
+  // llvm_unreachable instead of assert(false): the assert is compiled out in
+  // release builds, which would silently leave the CFG untouched.
+  llvm_unreachable("Unhandled terminator type.");
+}
+
+// Replaces the incoming basic block |OldSrc| of the OpPhi instructions in |BB|
+// by |NewSrc|. This function does not simplify the OpPhi instructions once
+// transformed.
+//
+// A phi can hold several incoming entries for the same predecessor (for
+// example when several switch cases of |OldSrc| target |BB|); all of them are
+// updated, so the phi stays consistent with the rewritten edges.
+void replacePhiTargets(BasicBlock *BB, BasicBlock *OldSrc, BasicBlock *NewSrc) {
+  for (PHINode &Phi : BB->phis())
+    Phi.replaceIncomingBlockWith(OldSrc, NewSrc);
+}
+
+} // anonymous namespace
+
+// Given a reducible CFG, produces a structurized CFG in the SPIR-V sense,
+// adding merge instructions when required.
+class SPIRVStructurizer : public FunctionPass {
+
+  struct DivergentConstruct;
+  // Represents a list of condition/loops/switch constructs.
+  // See SPIR-V 2.11.2. Structured Control-flow Constructs for the list of
+  // constructs.
+  using ConstructList = std::vector<std::unique_ptr<DivergentConstruct>>;
+
+  // Represents a divergent construct in the SPIR-V sense.
+    // Such a construct is represented by a header (entry), a merge block
+    // (exit), and possibly a continue block (back-edge). Each construct can
+    // contain other constructs, but their boundaries do not cross.
+  struct DivergentConstruct {
+    BasicBlock *Header = nullptr;
+    BasicBlock *Merge = nullptr;
+    BasicBlock *Continue = nullptr;
+
+    DivergentConstruct *Parent = nullptr;
+    ConstructList Children;
+  };
+
+  // A helper class to clean the construct boundaries.
+  // It is used to gather the list of blocks that should belong to each
+  // divergent construct, and possibly modify CFG edges when exits would cross
+  // the boundary of multiple constructs.
+  struct Splitter {
+    Function &F;
+    LoopInfo &LI;
+    DomTreeBuilder::BBDomTree DT;
+    DomTreeBuilder::BBPostDomTree PDT;
+
+    Splitter(Function &F, LoopInfo &LI) : F(F), LI(LI) { invalidate(); }
+
+    void invalidate() {
+      PDT.recalculate(F);
+      DT.recalculate(F);
+    }
+
+    // Returns the list of blocks that belongs to a SPIR-V continue construct.
+    std::vector<BasicBlock *> getContinueConstructBlocks(BasicBlock *Header,
+                                                         BasicBlock *Continue) {
+      std::vector<BasicBlock *> Output;
+      Loop *L = LI.getLoopFor(Continue);
+      BasicBlock *BackEdgeBlock = L->getLoopLatch();
+      assert(BackEdgeBlock);
----------------
Keenuts wrote:

Changed this to check for `assert(L->getLoopLatch() != nullptr);`.
What I wanted is to make sure we do have a back-edge block before traversing the tree.

https://github.com/llvm/llvm-project/pull/107408


More information about the cfe-commits mailing list