[clang] [llvm] [SPIR-V] Add SPIR-V structurizer (PR #107408)

Nathan Gauër via cfe-commits cfe-commits at lists.llvm.org
Thu Sep 19 07:46:17 PDT 2024


================
@@ -0,0 +1,1242 @@
+//===-- SPIRVStructurizer.cpp ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
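+//
+// This pass structurizes the CFG of a function in the SPIR-V sense: given a
+// reducible CFG, it adds the OpLoopMerge/OpSelectionMerge instructions
+// (spv_loop_merge/spv_selection_merge intrinsics) required by the SPIR-V
+// structured control-flow rules, splitting blocks and rerouting exit edges
+// where needed.
+//
+//===----------------------------------------------------------------------===//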
+
+#include "Analysis/SPIRVConvergenceRegionAnalysis.h"
+#include "SPIRV.h"
+#include "SPIRVSubtarget.h"
+#include "SPIRVTargetMachine.h"
+#include "SPIRVUtils.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/CodeGen/IntrinsicLowering.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Utils/LoopSimplify.h"
+#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
+#include <queue>
+#include <stack>
+#include <unordered_set>
+
+using namespace llvm;
+using namespace SPIRV;
+
+namespace llvm {
+
+void initializeSPIRVStructurizerPass(PassRegistry &);
+
+namespace {
+
+// A set of basic blocks (used for construct membership and visited sets).
+using BlockSet = std::unordered_set<BasicBlock *>;
+// A CFG edge, stored as (source block, destination block).
+using Edge = std::pair<BasicBlock *, BasicBlock *>;
+
+// Runs a partial-order traversal of the function owning |Start|, beginning at
+// |Start| and handing every reached block to |Op|. The traversal itself is
+// delegated to PartialOrderingVisitor; |Op|'s boolean result is forwarded to
+// it unchanged (callers return false for blocks they do not want expanded).
+void partialOrderVisit(BasicBlock &Start,
+                       std::function<bool(BasicBlock *)> Op) {
+  Function &Parent = *Start.getParent();
+  PartialOrderingVisitor Visitor(Parent);
+  Visitor.partialOrderVisit(Start, Op);
+}
+
+// Searches the convergence-region tree rooted at |Node| (node first, then its
+// children in order) for the region whose entry block is |BB|. Returns that
+// region, or nullptr when no region of the tree is entered through |BB|.
+const ConvergenceRegion *getRegionForHeader(const ConvergenceRegion *Node,
+                                            BasicBlock *BB) {
+  if (Node->Entry != BB) {
+    for (const ConvergenceRegion *Child : Node->Children)
+      if (const ConvergenceRegion *Found = getRegionForHeader(Child, BB))
+        return Found;
+    return nullptr;
+  }
+  return Node;
+}
+
+// Collects the blocks outside |CR| that are successors of its exiting blocks.
+// A well-formed region has at most one such target (checked by assertion).
+// Returns that target, or nullptr when the region has no outgoing edge.
+BasicBlock *getExitFor(const ConvergenceRegion *CR) {
+  std::unordered_set<BasicBlock *> Targets;
+  for (BasicBlock *Exiting : CR->Exits)
+    for (BasicBlock *Succ : successors(Exiting))
+      if (!CR->Blocks.count(Succ))
+        Targets.insert(Succ);
+
+  assert(Targets.size() <= 1);
+  return Targets.empty() ? nullptr : *Targets.begin();
+}
+
+// If |I| is a spv_loop_merge or spv_selection_merge intrinsic call, returns the
+// block named by its first operand (the merge target). Any other instruction
+// yields nullptr.
+BasicBlock *getDesignatedMergeBlock(Instruction *I) {
+  auto *Call = dyn_cast<IntrinsicInst>(I);
+  if (!Call)
+    return nullptr;
+
+  switch (Call->getIntrinsicID()) {
+  case Intrinsic::spv_loop_merge:
+  case Intrinsic::spv_selection_merge:
+    return cast<BlockAddress>(Call->getOperand(0))->getBasicBlock();
+  default:
+    return nullptr;
+  }
+}
+
+// If |I| is a spv_loop_merge intrinsic call (OpLoopMerge), returns the block
+// named by its second operand (the continue target); otherwise nullptr.
+BasicBlock *getDesignatedContinueBlock(Instruction *I) {
+  auto *Call = dyn_cast<IntrinsicInst>(I);
+  const bool IsLoopMerge =
+      Call && Call->getIntrinsicID() == Intrinsic::spv_loop_merge;
+  if (!IsLoopMerge)
+    return nullptr;
+  return cast<BlockAddress>(Call->getOperand(1))->getBasicBlock();
+}
+
+// Returns true if some merge instruction in |Header| names |Merge| as its merge
+// block. Note: despite the name, both OpSelectionMerge and OpLoopMerge match
+// (see getDesignatedMergeBlock).
+bool isDefinedAsSelectionMergeBy(BasicBlock &Header, BasicBlock &Merge) {
+  return std::any_of(Header.begin(), Header.end(), [&Merge](Instruction &I) {
+    return getDesignatedMergeBlock(&I) == &Merge;
+  });
+}
+
+// Returns true if |BB| contains at least one OpLoopMerge (spv_loop_merge)
+// instruction, i.e. an instruction designating a continue block.
+bool hasLoopMergeInstruction(BasicBlock &BB) {
+  return std::any_of(BB.begin(), BB.end(), [](Instruction &I) {
+    return getDesignatedContinueBlock(&I) != nullptr;
+  });
+}
+
+// Returns true if |I| is an OpSelectionMerge or OpLoopMerge instruction, that
+// is, if it designates a merge block; false otherwise.
+bool isMergeInstruction(Instruction *I) {
+  BasicBlock *Target = getDesignatedMergeBlock(I);
+  return Target != nullptr;
+}
+
+// Returns the set of blocks in |F| containing at least one OpLoopMerge or
+// OpSelectionMerge instruction.
+SmallPtrSet<BasicBlock *, 2> getHeaderBlocks(Function &F) {
+  SmallPtrSet<BasicBlock *, 2> Headers;
+  for (BasicBlock &BB : F) {
+    const bool HasMerge =
+        std::any_of(BB.begin(), BB.end(), [](Instruction &I) {
+          return getDesignatedMergeBlock(&I) != nullptr;
+        });
+    if (HasMerge)
+      Headers.insert(&BB);
+  }
+  return Headers;
+}
+
+// Returns the set of blocks in |F| named as merge target by at least one
+// OpSelectionMerge/OpLoopMerge instruction.
+SmallPtrSet<BasicBlock *, 2> getMergeBlocks(Function &F) {
+  SmallPtrSet<BasicBlock *, 2> Targets;
+  for (BasicBlock &BB : F)
+    for (Instruction &I : BB)
+      if (BasicBlock *Target = getDesignatedMergeBlock(&I))
+        Targets.insert(Target);
+  return Targets;
+}
+
+// Returns, in program order, every merge instruction (OpSelectionMerge or
+// OpLoopMerge) found in |BB|.
+// Note: the SPIR-V spec allows at most one merge instruction per block, but
+// intermediate states of this pass may temporarily contain several.
+std::vector<Instruction *> getMergeInstructions(BasicBlock &BB) {
+  std::vector<Instruction *> MergeInsts;
+  for (Instruction &Inst : BB) {
+    if (!isMergeInstruction(&Inst))
+      continue;
+    MergeInsts.push_back(&Inst);
+  }
+  return MergeInsts;
+}
+
+// Returns the set of blocks in |F| named as continue target by at least one
+// OpLoopMerge instruction.
+SmallPtrSet<BasicBlock *, 2> getContinueBlocks(Function &F) {
+  SmallPtrSet<BasicBlock *, 2> Continues;
+  for (BasicBlock &BB : F) {
+    for (Instruction &I : BB) {
+      BasicBlock *Target = getDesignatedContinueBlock(&I);
+      if (!Target)
+        continue;
+      Continues.insert(Target);
+    }
+  }
+  return Continues;
+}
+
+// Stack-based traversal of the CFG starting at |Start|. |op| is called once on
+// every block that is reached; when it returns false, that block's successors
+// are not explored from it. A block is pushed at most once.
+void visit(BasicBlock &Start, std::function<bool(BasicBlock *)> op) {
+  SmallPtrSet<BasicBlock *, 8> Queued;
+  std::stack<BasicBlock *> Pending;
+
+  auto Enqueue = [&](BasicBlock *Block) {
+    Pending.push(Block);
+    Queued.insert(Block);
+  };
+
+  Enqueue(&Start);
+  while (!Pending.empty()) {
+    BasicBlock *Current = Pending.top();
+    Pending.pop();
+    if (!op(Current))
+      continue;
+    for (BasicBlock *Succ : successors(Current))
+      if (!Queued.contains(Succ))
+        Enqueue(Succ);
+  }
+}
+
+// Rewrites every successor of the BranchInst terminating |BB| that equals
+// |OldTarget| to |NewTarget|. If a conditional branch ends up with two
+// identical successors, it is replaced by an unconditional branch; an
+// spv_selection_merge immediately preceding it is then erased too, because an
+// OpSelectionMerge must be followed by a conditional branch. A preceding
+// spv_loop_merge is left in place.
+void replaceIfBranchTargets(BasicBlock *BB, BasicBlock *OldTarget,
+                            BasicBlock *NewTarget) {
+  auto *Branch = cast<BranchInst>(BB->getTerminator());
+
+  for (unsigned Idx = 0, E = Branch->getNumSuccessors(); Idx != E; ++Idx)
+    if (Branch->getSuccessor(Idx) == OldTarget)
+      Branch->setSuccessor(Idx, NewTarget);
+
+  // Only a conditional branch whose two successors are now identical needs
+  // folding. Such branches may also come from the input IR, so the common
+  // target is not necessarily |NewTarget|.
+  if (!Branch->isConditional() ||
+      Branch->getSuccessor(0) != Branch->getSuccessor(1))
+    return;
+
+  BasicBlock *CommonTarget = Branch->getSuccessor(0);
+  IRBuilder<> Builder(BB);
+  Builder.SetInsertPoint(Branch);
+  Builder.CreateBr(CommonTarget);
+  Branch->eraseFromParent();
+
+  // The new branch is alone in the block: no merge instruction to clean up.
+  if (BB->size() == 1)
+    return;
+
+  auto *Merge = dyn_cast<IntrinsicInst>(BB->getTerminator()->getPrevNode());
+  if (!Merge || Merge->getIntrinsicID() != Intrinsic::spv_selection_merge)
+    return;
+
+  // Drop the merge and, when nothing else references it, its BlockAddress.
+  auto *MergeTarget = cast<Constant>(Merge->getOperand(0));
+  Merge->eraseFromParent();
+  if (!MergeTarget->isConstantUsed())
+    MergeTarget->destroyConstant();
+}
+
+// Replaces the target of the terminator of |BB| with |NewTarget| wherever it
+// was |OldTarget|.
+// - Terminators without successors (ret, unreachable) are left untouched.
+// - Branches go through replaceIfBranchTargets: a conditional branch whose two
+//   successors become identical is folded into an unconditional one, and a
+//   preceding OpSelectionMerge is removed accordingly.
+// - Switches only get their targets updated. See also: simplifyBranches.
+// Any other terminator kind is a programming error.
+void replaceBranchTargets(BasicBlock *BB, BasicBlock *OldTarget,
+                          BasicBlock *NewTarget) {
+  auto *T = BB->getTerminator();
+  // Nothing to retarget. This covers `ret` and the `unreachable` blocks this
+  // pass creates itself.
+  if (T->getNumSuccessors() == 0)
+    return;
+
+  if (isa<BranchInst>(T))
+    return replaceIfBranchTargets(BB, OldTarget, NewTarget);
+
+  if (auto *SI = dyn_cast<SwitchInst>(T)) {
+    for (unsigned i = 0, E = SI->getNumSuccessors(); i != E; ++i) {
+      if (SI->getSuccessor(i) == OldTarget)
+        SI->setSuccessor(i, NewTarget);
+    }
+    return;
+  }
+
+  // llvm_unreachable rather than assert(false): the latter silently falls
+  // through in release builds.
+  llvm_unreachable("Unhandled terminator type.");
+}
+
+// In every PHI node of |BB|, rewrites the first incoming-block operand equal to
+// |OldSrc| so that it names |NewSrc| instead. Incoming values are kept and the
+// PHI nodes are not simplified afterwards. If a PHI lists |OldSrc| more than
+// once, only its first entry is rewritten.
+void replacePhiTargets(BasicBlock *BB, BasicBlock *OldSrc, BasicBlock *NewSrc) {
+  for (PHINode &Phi : BB->phis()) {
+    const int Idx = Phi.getBasicBlockIndex(OldSrc);
+    if (Idx >= 0)
+      Phi.setIncomingBlock(Idx, NewSrc);
+  }
+}
+
+} // anonymous namespace
+
+// Given a reducible CFG, produces a structurized CFG in the SPIR-V sense,
+// adding merge instructions when required.
+class SPIRVStructurizer : public FunctionPass {
+
+  struct DivergentConstruct;
+  // Represents a list of condition/loops/switch constructs.
+  // See SPIR-V 2.11.2. Structured Control-flow Constructs for the list of
+  // constructs. Each element is owned by the list through a unique_ptr.
+  using ConstructList = std::vector<std::unique_ptr<DivergentConstruct>>;
+
+  // Represents a divergent construct in the SPIR-V sense.
+  // Such constructs are represented by a header (entry), a merge block (exit),
+  // and possibly a continue block (back-edge). A construct can contain other
+  // constructs, but their boundaries do not cross.
+  struct DivergentConstruct {
+    // Block holding the construct's merge instruction (entry).
+    BasicBlock *Header = nullptr;
+    // Merge block designated by the header's merge instruction (exit).
+    BasicBlock *Merge = nullptr;
+    // Continue target of an OpLoopMerge; nullptr for selection constructs.
+    BasicBlock *Continue = nullptr;
+
+    // Enclosing construct. The artificial root of the tree has no parent and
+    // null Header/Merge fields.
+    DivergentConstruct *Parent = nullptr;
+    // Constructs nested directly inside this one (owned).
+    ConstructList Children;
+  };
+
+  // A helper to clean up construct boundaries.
+  // It gathers the list of blocks that belong to each divergent construct and,
+  // when exits would cross the boundary of multiple constructs, rewrites CFG
+  // edges so a construct exits through a single new block. It caches a
+  // dominator tree and a post-dominator tree of |F|; call invalidate() after
+  // changing the CFG.
+  struct Splitter {
+    Function &F;
+    LoopInfo &LI;
+    DomTreeBuilder::BBDomTree DT;
+    DomTreeBuilder::BBPostDomTree PDT;
+
+    Splitter(Function &F, LoopInfo &LI) : F(F), LI(LI) { invalidate(); }
+
+    // Recomputes the cached dominator and post-dominator trees of |F|.
+    void invalidate() {
+      PDT.recalculate(F);
+      DT.recalculate(F);
+    }
+
+    // Returns the list of blocks that belong to a SPIR-V loop construct,
+    // including the continue construct: blocks reached from |Header| that
+    // |Header| dominates and |Merge| does not, |Merge| itself excluded.
+    // |Header| must dominate |Merge|.
+    std::vector<BasicBlock *> getLoopConstructBlocks(BasicBlock *Header,
+                                                     BasicBlock *Merge) {
+      assert(DT.dominates(Header, Merge));
+      std::vector<BasicBlock *> Output;
+      partialOrderVisit(*Header, [&](BasicBlock *BB) {
+        if (BB == Merge)
+          return false;
+        if (DT.dominates(Merge, BB) || !DT.dominates(Header, BB))
+          return false;
+        Output.push_back(BB);
+        return true;
+      });
+      return Output;
+    }
+
+    // Returns the list of blocks that belong to the SPIR-V selection construct
+    // |Node|: blocks reached from its header that the header dominates and its
+    // merge block does not. The traversal also stops at the merge and continue
+    // blocks of every enclosing construct.
+    std::vector<BasicBlock *>
+    getSelectionConstructBlocks(DivergentConstruct *Node) {
+      assert(DT.dominates(Node->Header, Node->Merge));
+      BlockSet OutsideBlocks;
+      OutsideBlocks.insert(Node->Merge);
+
+      for (DivergentConstruct *It = Node->Parent; It != nullptr;
+           It = It->Parent) {
+        OutsideBlocks.insert(It->Merge);
+        if (It->Continue)
+          OutsideBlocks.insert(It->Continue);
+      }
+
+      std::vector<BasicBlock *> Output;
+      partialOrderVisit(*Node->Header, [&](BasicBlock *BB) {
+        if (OutsideBlocks.count(BB) != 0)
+          return false;
+        if (DT.dominates(Node->Merge, BB) || !DT.dominates(Node->Header, BB))
+          return false;
+        Output.push_back(BB);
+        return true;
+      });
+      return Output;
+    }
+
+    // Returns the list of blocks that belong to a SPIR-V switch construct.
+    std::vector<BasicBlock *> getSwitchConstructBlocks(BasicBlock *Header,
+                                                       BasicBlock *Merge) {
+      assert(DT.dominates(Header, Merge));
+
+      std::vector<BasicBlock *> Output;
+      partialOrderVisit(*Header, [&](BasicBlock *BB) {
+        // the blocks structurally dominated by a switch header,
+        if (!DT.dominates(Header, BB))
+          return false;
+        // excluding blocks structurally dominated by the switch header's merge
+        // block.
+        if (DT.dominates(Merge, BB) || BB == Merge)
+          return false;
+        Output.push_back(BB);
+        return true;
+      });
+      return Output;
+    }
+
+    // Returns the list of blocks that belong to a SPIR-V case construct.
+    std::vector<BasicBlock *> getCaseConstructBlocks(BasicBlock *Target,
+                                                     BasicBlock *Merge) {
+      assert(DT.dominates(Target, Merge));
+
+      std::vector<BasicBlock *> Output;
+      partialOrderVisit(*Target, [&](BasicBlock *BB) {
+        // the blocks structurally dominated by an OpSwitch Target or Default
+        // block
+        if (!DT.dominates(Target, BB))
+          return false;
+        // excluding the blocks structurally dominated by the OpSwitch
+        // construct's corresponding merge block.
+        if (DT.dominates(Merge, BB) || BB == Merge)
+          return false;
+        Output.push_back(BB);
+        return true;
+      });
+      return Output;
+    }
+
+    // Returns |Edges| with exact duplicates removed and with every source
+    // block paired with at most one destination.
+    //
+    // createSingleExitNode routes all exit edges through one new exit block
+    // whose PHI records, per predecessor, which destination to continue to.
+    // A PHI holds one value per predecessor, so when a block Src exits to two
+    // distinct blocks (Src -> C and Src -> E), the second edge is first split
+    // with a new empty block "new.exit.src":
+    //   Src -> C
+    //   Src -> new.exit.src -> E
+    // and the PHI nodes of E that listed Src (first entry) now list the alias.
+    // An edge listed several times (e.g. a switch with several cases going to
+    // the same block) is handled only once, so no dead alias block is created.
+    std::vector<Edge>
+    createAliasBlocksForComplexEdges(const std::vector<Edge> &Edges) {
+      // First destination recorded for each source block.
+      std::unordered_map<BasicBlock *, BasicBlock *> Seen;
+      // Edges already rerouted through an alias block, mapped to that alias.
+      DenseMap<Edge, BasicBlock *> AliasFor;
+      std::vector<Edge> Output;
+
+      for (const auto &[Src, Dst] : Edges) {
+        auto [It, Inserted] = Seen.insert({Src, Dst});
+        if (Inserted) {
+          Output.emplace_back(Src, Dst);
+          continue;
+        }
+
+        // The exact same edge was already handled (either kept as-is, or
+        // rerouted through an alias). Ignoring.
+        if (It->second == Dst || AliasFor.count(Edge(Src, Dst)) != 0)
+          continue;
+
+        // The same Src block branches to 2 distinct blocks. This will be an
+        // issue for the generated OpPhi. Creating alias block.
+        BasicBlock *NewSrc =
+            BasicBlock::Create(F.getContext(), "new.exit.src", &F);
+        replaceBranchTargets(Src, Dst, NewSrc);
+        replacePhiTargets(Dst, Src, NewSrc);
+
+        IRBuilder<> Builder(NewSrc);
+        Builder.CreateBr(Dst);
+
+        AliasFor[Edge(Src, Dst)] = NewSrc;
+        Output.emplace_back(NewSrc, Dst);
+      }
+
+      return Output;
+    }
+
+    // Given a construct defined by |Header| (not used by the current
+    // implementation) and its exiting edges |Edges|, creates a new block
+    // "new.exit" and redirects every exit edge to it. The new block continues
+    // to the original destination(s):
+    // - a single destination: an unconditional branch;
+    // - two destinations: a conditional branch;
+    // - more: a switch.
+    // In the last two cases the choice is driven by an i32 PHI recording which
+    // predecessor the exit came from. Returns the new exit block.
+    BasicBlock *createSingleExitNode(BasicBlock *Header,
+                                     std::vector<Edge> &Edges) {
+      auto NewExit = BasicBlock::Create(F.getContext(), "new.exit", &F);
+      IRBuilder<> ExitBuilder(NewExit);
+
+      std::vector<BasicBlock *> Dsts;
+      std::unordered_map<BasicBlock *, ConstantInt *> DstToIndex;
+
+      // A source exiting to several destinations would need several values in
+      // the selector PHI below, so such edges are split first.
+      std::vector<Edge> FixedEdges = createAliasBlocksForComplexEdges(Edges);
+
+      // Number the distinct destinations in first-seen order.
+      for (auto &[Src, Dst] : FixedEdges) {
+        if (DstToIndex.count(Dst) != 0)
+          continue;
+        DstToIndex.emplace(Dst, ExitBuilder.getInt32(DstToIndex.size()));
+        Dsts.push_back(Dst);
+      }
+
+      // NOTE(review): each Src entry of Dst's PHI nodes is renamed to NewExit.
+      // If several sources exit to the same Dst and Dst has PHI nodes, those
+      // PHIs end up with several NewExit entries while NewExit reaches Dst
+      // through a single edge. Confirm this case is excluded upstream.
+      if (Dsts.size() == 1) {
+        for (auto &[Src, Dst] : FixedEdges) {
+          replaceBranchTargets(Src, Dst, NewExit);
+          replacePhiTargets(Dst, Src, NewExit);
+        }
+        ExitBuilder.CreateBr(Dsts[0]);
+        return NewExit;
+      }
+
+      PHINode *PhiNode =
+          ExitBuilder.CreatePHI(ExitBuilder.getInt32Ty(), FixedEdges.size());
+
+      for (auto &[Src, Dst] : FixedEdges) {
+        PhiNode->addIncoming(DstToIndex[Dst], Src);
+        replaceBranchTargets(Src, Dst, NewExit);
+        replacePhiTargets(Dst, Src, NewExit);
+      }
+
+      // If we can avoid an OpSwitch, generate an OpBranch. Reason is some
+      // OpBranch are allowed to exist without a new OpSelectionMerge if one of
+      // the branch is the parent's merge node, while OpSwitches are not.
+      if (Dsts.size() == 2) {
+        Value *Condition = ExitBuilder.CreateCmp(CmpInst::ICMP_EQ,
+                                                 DstToIndex[Dsts[0]], PhiNode);
+        ExitBuilder.CreateCondBr(Condition, Dsts[0], Dsts[1]);
+        return NewExit;
+      }
+
+      SwitchInst *Sw =
+          ExitBuilder.CreateSwitch(PhiNode, Dsts[0], Dsts.size() - 1);
+      for (auto It = Dsts.begin() + 1; It != Dsts.end(); ++It) {
+        Sw->addCase(DstToIndex[*It], *It);
+      }
+      return NewExit;
+    }
+  };
+
+  /// Returns the value from |TargetToValue| associated with the successor the
+  /// terminator of |BB| will branch to. When both successors of a
+  /// conditional branch are mapped, a select on the branch condition is
+  /// emitted right before the terminator; when only one is, its value is
+  /// returned directly. Returns nullptr for a block ending in `ret` or when
+  /// no successor is mapped. Other terminators are not supported yet.
+  Value *createExitVariable(
+      BasicBlock *BB,
+      const DenseMap<BasicBlock *, ConstantInt *> &TargetToValue) {
+    Instruction *Term = BB->getTerminator();
+    if (isa<ReturnInst>(Term))
+      return nullptr;
+
+    IRBuilder<> Builder(BB);
+    Builder.SetInsertPoint(Term);
+
+    auto *Br = dyn_cast<BranchInst>(Term);
+    if (Br == nullptr) {
+      // TODO: add support for switch cases.
+      llvm_unreachable("Unhandled terminator type.");
+    }
+
+    // An unconditional branch has no second target; the null key is simply
+    // looked up like any other.
+    BasicBlock *SecondTarget =
+        Br->isConditional() ? Br->getSuccessor(1) : nullptr;
+    Value *FirstValue = TargetToValue.lookup(Br->getSuccessor(0));
+    Value *SecondValue = TargetToValue.lookup(SecondTarget);
+
+    if (FirstValue != nullptr && SecondValue != nullptr)
+      return Builder.CreateSelect(Br->getCondition(), FirstValue, SecondValue);
+    return FirstValue != nullptr ? FirstValue : SecondValue;
+  }
+
+  // Appends to |F| a fresh block named "new.exit" whose only instruction is an
+  // `unreachable` terminator, and returns it.
+  BasicBlock *CreateUnreachable(Function &F) {
+    auto *Block = BasicBlock::Create(F.getContext(), "new.exit", &F);
+    IRBuilder<> UnreachableBuilder(Block);
+    UnreachableBuilder.CreateUnreachable();
+    return Block;
+  }
+
+  // Adds an OpLoopMerge (spv_loop_merge) to each loop header that is also the
+  // entry of a convergence region. The merge target is the region's single
+  // exit block, or a new unreachable block for loops without any exit. The
+  // continue target is the loop latch. Returns true if any merge was added.
+  bool addMergeForLoops(Function &F) {
+    LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+    auto *TopLevelRegion =
+        getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
+            .getRegionInfo()
+            .getTopLevelRegion();
+
+    bool Modified = false;
+    for (auto &BB : F) {
+      // Not a loop header. Ignoring for now.
+      if (!LI.isLoopHeader(&BB))
+        continue;
+      auto *L = LI.getLoopFor(&BB);
+
+      // This loop header is not the entrance of a convergence region. Ignoring
+      // this block.
+      auto *CR = getRegionForHeader(TopLevelRegion, &BB);
+      if (CR == nullptr)
+        continue;
+
+      IRBuilder<> Builder(&BB);
+
+      auto *Merge = getExitFor(CR);
+      // We are indeed in a loop, but there are no exits (infinite loop).
+      // This could be caused by a bad shader, but also could be an artifact
+      // from an earlier optimization. It is not always clear if structurally
+      // reachable means runtime reachable, so we cannot error-out. What we must
+      // do however is to make it legal on the SPIR-V point of view, hence
+      // adding an unreachable merge block, reached through a `br false`.
+      if (Merge == nullptr) {
+        BranchInst *Br = cast<BranchInst>(BB.getTerminator());
+        assert(Br->isUnconditional());
+
+        Merge = CreateUnreachable(F);
+        Builder.SetInsertPoint(Br);
+        Builder.CreateCondBr(Builder.getFalse(), Merge, Br->getSuccessor(0));
+        Br->eraseFromParent();
+      }
+
+      // OpLoopMerge names a single continue target. getLoopLatch() returns
+      // nullptr when the loop has more than one latch, which would otherwise
+      // crash below.
+      auto *Continue = L->getLoopLatch();
+      assert(Continue && "Expected the loop to have a single latch block.");
+
+      Builder.SetInsertPoint(BB.getTerminator());
+      auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge);
+      auto ContinueAddress = BlockAddress::get(Continue->getParent(), Continue);
+      SmallVector<Value *, 2> Args = {MergeAddress, ContinueAddress};
+
+      Builder.CreateIntrinsic(Intrinsic::spv_loop_merge, {}, {Args});
+      Modified = true;
+    }
+
+    return Modified;
+  }
+
+  // For each block with an in-degree of 2 or more, inserts an OpSelectionMerge
+  // designating it before the terminator of its immediate dominator, unless
+  // that dominator already has a merge instruction (selection or loop) naming
+  // it. Blocks holding an OpLoopMerge with at most 2 predecessors are skipped.
+  // Returns true if any merge instruction was added.
+  bool addMergeForNodesWithMultiplePredecessors(Function &F) {
+    DomTreeBuilder::BBDomTree DT;
+    DT.recalculate(F);
+
+    bool Modified = false;
+    for (auto &BB : F) {
+      const auto NumPreds = pred_size(&BB);
+      if (NumPreds <= 1)
+        continue;
+      // Presumably a loop header entered from the preheader and the latch
+      // only: its OpLoopMerge already covers it.
+      if (NumPreds <= 2 && hasLoopMergeInstruction(BB))
+        continue;
+
+      auto *IDomNode = DT.getNode(&BB)->getIDom();
+      assert(IDomNode);
+      BasicBlock *Header = IDomNode->getBlock();
+      if (isDefinedAsSelectionMergeBy(*Header, BB))
+        continue;
+
+      IRBuilder<> Builder(Header->getTerminator());
+      Value *Args[] = {BlockAddress::get(BB.getParent(), &BB)};
+      Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, Args);
+      Modified = true;
+    }
+
+    return Modified;
+  }
+
+  // When a block has multiple OpSelectionMerge/OpLoopMerge instructions, sorts
+  // them to put the "largest" first. A merge instruction is defined as larger
+  // than another when its target merge block post-dominates the other target's
+  // merge block. (This ordering should match the nesting ordering of the source
+  // HLSL). Merge blocks are ranked with PartialOrderingVisitor::compare.
+  // Returns true if |Block| had more than one merge instruction.
+  bool sortSelectionMerge(Function &F, BasicBlock &Block) {
+    std::vector<Instruction *> MergeInstructions;
+    for (Instruction &I : Block)
+      if (isMergeInstruction(&I))
+        MergeInstructions.push_back(&I);
+
+    if (MergeInstructions.size() <= 1)
+      return false;
+
+    Instruction *InsertionPoint = *MergeInstructions.begin();
+
+    PartialOrderingVisitor Visitor(F);
+    // Ascending order of merge block. std::sort needs a strict weak ordering:
+    // two merge instructions naming the same block must compare equivalent
+    // (false both ways). Hence compare(Left, Right) and not
+    // !compare(Right, Left), which is true both ways for equal blocks.
+    // Assumes PartialOrderingVisitor::compare is a strict less-than.
+    std::sort(MergeInstructions.begin(), MergeInstructions.end(),
+              [&Visitor](Instruction *Left, Instruction *Right) {
+                BasicBlock *LeftMerge = getDesignatedMergeBlock(Left);
+                BasicBlock *RightMerge = getDesignatedMergeBlock(Right);
+                return Visitor.compare(LeftMerge, RightMerge);
+              });
+
+    // Moving each instruction before the previously placed one reverses the
+    // sorted order, so the latest merge block ends up first in |Block|.
+    for (Instruction *I : MergeInstructions) {
+      I->moveBefore(InsertionPoint);
+      InsertionPoint = I;
+    }
+
+    return true;
+  }
+
+  // Runs sortSelectionMerge on every block of |F|, so that within each block
+  // the merge instructions appear largest-first. Returns true if at least one
+  // block had more than one merge instruction.
+  bool sortSelectionMergeHeaders(Function &F) {
+    bool Changed = false;
+    for (BasicBlock &BB : F)
+      if (sortSelectionMerge(F, BB))
+        Changed = true;
+    return Changed;
+  }
+
+  // Splits basic blocks containing multiple OpLoopMerge/OpSelectionMerge
+  // instructions so each resulting block holds a single merge instruction.
+  // The block is cut right before every merge instruction after the first,
+  // producing a chain of blocks linked by branches. Because an
+  // OpSelectionMerge must be followed by a conditional branch, a piece that
+  // holds a selection merge gets its fall-through branch rewritten into
+  // `br true, next, unreachable`. Returns true if any block was split.
+  bool splitBlocksWithMultipleHeaders(Function &F) {
+    std::stack<BasicBlock *> Work;
+    for (auto &BB : F) {
+      std::vector<Instruction *> MergeInstructions = getMergeInstructions(BB);
+      if (MergeInstructions.size() <= 1)
+        continue;
+      Work.push(&BB);
+    }
+
+    const bool Modified = Work.size() > 0;
+    while (Work.size() > 0) {
+      BasicBlock *Header = Work.top();
+      Work.pop();
+
+      std::vector<Instruction *> MergeInstructions =
+          getMergeInstructions(*Header);
+      for (unsigned i = 1; i < MergeInstructions.size(); i++) {
+        BasicBlock *NewBlock =
+            Header->splitBasicBlock(MergeInstructions[i], "new.header");
+
+        // |Header| now holds exactly one merge instruction, the previous one
+        // (MergeInstructions[i - 1]), followed by the unconditional branch to
+        // |NewBlock| created by the split. Checking MergeInstructions[0]
+        // instead would leave later selection merges with an unconditional
+        // branch whenever the first merge is an OpLoopMerge.
+        if (getDesignatedContinueBlock(MergeInstructions[i - 1]) == nullptr) {
+          BasicBlock *Unreachable = CreateUnreachable(F);
+
+          BranchInst *BI = cast<BranchInst>(Header->getTerminator());
+          IRBuilder<> Builder(Header);
+          Builder.SetInsertPoint(BI);
+          Builder.CreateCondBr(Builder.getTrue(), NewBlock, Unreachable);
+          BI->eraseFromParent();
+        }
+
+        Header = NewBlock;
+      }
+    }
+
+    return Modified;
+  }
+
+  // Adds an OpSelectionMerge to each block that has no merge instruction yet
+  // and at least 2 "candidate" successors, i.e. successors that were not
+  // already merge or continue targets when this function started. The first
+  // candidate successor is designated as the merge block.
+  // Returns true if any merge instruction was added.
+  bool addMergeForDivergentBlocks(Function &F) {
+    bool Modified = false;
+
+    // Computed once up front: merge instructions added by the loop below are
+    // not reflected in these sets.
+    auto MergeBlocks = getMergeBlocks(F);
+    auto ContinueBlocks = getContinueBlocks(F);
+
+    for (auto &BB : F) {
+      if (!getMergeInstructions(BB).empty())
+        continue;
+
+      std::vector<BasicBlock *> Candidates;
+      for (BasicBlock *Successor : successors(&BB)) {
+        if (MergeBlocks.contains(Successor))
+          continue;
+        if (ContinueBlocks.contains(Successor))
+          continue;
+        Candidates.push_back(Successor);
+      }
+
+      if (Candidates.size() <= 1)
+        continue;
+
+      Modified = true;
+      BasicBlock *Merge = Candidates[0];
+
+      auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge);
+      SmallVector<Value *, 1> Args = {MergeAddress};
+      IRBuilder<> Builder(&BB);
+      Builder.SetInsertPoint(BB.getTerminator());
+      Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
+    }
+
+    return Modified;
+  }
+
+  // Returns every edge (Src, Dst) such that Src belongs to |Construct| and is
+  // reachable from |Header| without leaving |Construct|, and Dst lies outside
+  // |Construct|. An edge is listed once per successor slot of Src, so
+  // duplicates are possible.
+  std::vector<Edge> getExitsFrom(const BlockSet &Construct,
+                                 BasicBlock &Header) {
+    std::vector<Edge> Exits;
+    auto CollectExits = [&](BasicBlock *Block) {
+      const bool Inside = Construct.count(Block) != 0;
+      if (Inside)
+        for (BasicBlock *Succ : successors(Block))
+          if (!Construct.count(Succ))
+            Exits.emplace_back(Block, Succ);
+      return Inside;
+    };
+    visit(Header, CollectExits);
+    return Exits;
+  }
+
+  // Build a divergent construct tree searching from |BB|.
+  // Walks the CFG recursively, visiting each block at most once (tracked in
+  // |Visited|). A block without merge instruction just forwards |Parent| to
+  // its successors. A block with a merge instruction becomes a new construct:
+  // - its merge block is explored under |Parent|, since it belongs to the
+  //   enclosing construct;
+  // - its continue block (if any) and its successors are explored under the
+  //   new construct.
+  // The new construct is appended to |Parent|'s children. If |Parent| is null,
+  // it is discarded when this function returns. |S| is currently unused.
+  // NOTE(review): recursion depth grows with the CFG path length.
+  void constructDivergentConstruct(BlockSet &Visited, Splitter &S,
+                                   BasicBlock *BB, DivergentConstruct *Parent) {
+    if (Visited.count(BB) != 0)
+      return;
+    Visited.insert(BB);
+
+    auto MIS = getMergeInstructions(*BB);
+    if (MIS.size() == 0) {
+      for (BasicBlock *Successor : successors(BB))
+        constructDivergentConstruct(Visited, S, Successor, Parent);
+      return;
+    }
+
+    // Blocks must have been split to hold a single merge instruction.
+    assert(MIS.size() == 1);
+    Instruction *MI = MIS[0];
+
+    BasicBlock *Merge = getDesignatedMergeBlock(MI);
+    BasicBlock *Continue = getDesignatedContinueBlock(MI);
+
+    auto Output = std::make_unique<DivergentConstruct>();
+    Output->Header = BB;
+    Output->Merge = Merge;
+    Output->Continue = Continue;
+    Output->Parent = Parent;
+
+    // The merge block is visited first so it is marked as visited (and
+    // attributed to |Parent|) before the body could reach it.
+    constructDivergentConstruct(Visited, S, Merge, Parent);
+    if (Continue)
+      constructDivergentConstruct(Visited, S, Continue, Output.get());
+
+    for (BasicBlock *Successor : successors(BB))
+      constructDivergentConstruct(Visited, S, Successor, Output.get());
+
+    if (Parent)
+      Parent->Children.emplace_back(std::move(Output));
+  }
+
+  // Returns the blocks of the divergent construct |Node|: those of the loop
+  // construct (continue construct included) when |Node| has a continue block,
+  // those of the selection construct otherwise.
+  BlockSet getConstructBlocks(Splitter &S, DivergentConstruct *Node) {
+    assert(Node->Header && Node->Merge);
+
+    const std::vector<BasicBlock *> Blocks =
+        Node->Continue ? S.getLoopConstructBlocks(Node->Header, Node->Merge)
+                       : S.getSelectionConstructBlocks(Node);
+    return BlockSet(Blocks.begin(), Blocks.end());
+  }
+
+  // Fixup the construct |Node| to respect a set of rules defined by the SPIR-V
+  // spec. Children are fixed first (post-order). When one of |Node|'s exit
+  // edges is not allowed, all of its exit edges are funneled through a new
+  // single exit block, which then becomes |Node|'s merge block. The merge
+  // instruction and the splitter's dominator trees are updated accordingly.
+  void fixupConstruct(Splitter &S, DivergentConstruct *Node) {
+    for (auto &Child : Node->Children)
+      fixupConstruct(S, Child.get());
+
+    // This construct is the root construct. Does not represent any real
+    // construct, just a way to access the first level of the forest.
+    if (Node->Parent == nullptr)
+      return;
+
+    // This node's parent is the root. Meaning this is a top-level construct.
+    // There can be multiple exits, but all are guaranteed to exit at most 1
+    // construct since we are at first level.
+    if (Node->Parent->Header == nullptr)
+      return;
+
+    // Health check for the structure.
+    assert(Node->Header && Node->Merge);
+    assert(Node->Parent->Header && Node->Parent->Merge);
+
+    BlockSet ConstructBlocks = getConstructBlocks(S, Node);
+    auto Edges = getExitsFrom(ConstructBlocks, *Node->Header);
+
+    //  No edges exiting the construct.
+    if (Edges.size() < 1)
+      return;
+
+    // Sharing the parent's merge or continue block is not allowed, even when
+    // every individual exit edge is.
+    bool HasBadEdge = Node->Merge == Node->Parent->Merge ||
+                      Node->Merge == Node->Parent->Continue;
+    for (auto &[Src, Dst] : Edges) {
+      // Allowed exits, quoting the spec (branch from A = Src to B = Dst, where
+      // "S" is the construct, here |Node|):
+      // - Breaking from a selection construct: S is a selection construct, S is
+      //   the innermost structured control-flow construct containing A, and B
+      //   is the merge block for S
+      // - Breaking from the innermost loop: S is the innermost loop construct
+      //   containing A, and B is the merge block for S
+      if (Node->Merge == Dst)
+        continue;
+
+      // Entering the innermost loop's continue construct: S is the innermost
+      // loop construct containing A, and B is the continue target for S
+      if (Node->Continue == Dst)
+        continue;
+
+      // TODO: what about cases branching to another case in the switch? Seems
+      // to work, but need to double check.
+      HasBadEdge = true;
+    }
+
+    if (!HasBadEdge)
+      return;
+
+    // Create a single exit node gathering all exit edges.
+    BasicBlock *NewExit = S.createSingleExitNode(Node->Header, Edges);
+
+    // Fixup this construct's merge node to point to the new exit.
+    // Note: this algorithm fixes inner-most divergence construct first. So
+    // recursive structures sharing a single merge node are fixed from the
+    // inside toward the outside.
+    auto MergeInstructions = getMergeInstructions(*Node->Header);
+    assert(MergeInstructions.size() == 1);
+    Instruction *I = MergeInstructions[0];
+    BlockAddress *BA = cast<BlockAddress>(I->getOperand(0));
+    if (BA->getBasicBlock() == Node->Merge) {
+      auto MergeAddress = BlockAddress::get(NewExit->getParent(), NewExit);
+      I->setOperand(0, MergeAddress);
+    }
+
+    // Clean up the possibly dangling BlockAddress operand to prevent MIR
+    // comments about "address of removed block taken".
+    if (!BA->isConstantUsed())
+      BA->destroyConstant();
+
+    Node->Merge = NewExit;
+    // Regenerate the dom trees.
+    S.invalidate();
+  }
+
+  bool splitCriticalEdges(Function &F) {
+    LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+    Splitter S(F, LI);
+
+    DivergentConstruct Root;
+    BlockSet Visited;
+    constructDivergentConstruct(Visited, S, &*F.begin(), &Root);
+    fixupConstruct(S, &Root);
+
+    return true;
----------------
Keenuts wrote:

right, fixed!

https://github.com/llvm/llvm-project/pull/107408


More information about the cfe-commits mailing list