[llvm] [SPIR-V] Add pass to merge convergence region exit targets (PR #92531)

Nathan Gauër via llvm-commits llvm-commits at lists.llvm.org
Mon May 27 08:43:54 PDT 2024


================
@@ -0,0 +1,287 @@
+//===-- SPIRVMergeRegionExitTargets.cpp ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Merge the multiple exit targets of a convergence region into a single block.
+// Each exit target will be assigned a constant value, and a phi node + switch
+// will allow the new exit target to re-route to the correct basic block.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Analysis/SPIRVConvergenceRegionAnalysis.h"
+#include "SPIRV.h"
+#include "SPIRVSubtarget.h"
+#include "SPIRVTargetMachine.h"
+#include "SPIRVUtils.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/CodeGen/IntrinsicLowering.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Utils/LoopSimplify.h"
+#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
+
+using namespace llvm;
+
+namespace llvm {
+void initializeSPIRVMergeRegionExitTargetsPass(PassRegistry &);
+} // namespace llvm
+
+namespace llvm {
+
+class SPIRVMergeRegionExitTargets : public FunctionPass {
+public:
+  static char ID;
+
+  /// Constructs the pass and runs its initializer against the global pass
+  /// registry.
+  SPIRVMergeRegionExitTargets() : FunctionPass(ID) {
+    initializeSPIRVMergeRegionExitTargetsPass(*PassRegistry::getPassRegistry());
+  }
+
+  /// Gathers the distinct successors of |BB|.
+  ///
+  /// A branch contributes its one or two targets, a switch contributes its
+  /// default destination plus every case successor, and a return contributes
+  /// nothing. Asserts if the terminator is neither a branch, a switch, nor a
+  /// return.
+  std::unordered_set<BasicBlock *> gatherSuccessors(BasicBlock *BB) {
+    std::unordered_set<BasicBlock *> output;
+    auto *T = BB->getTerminator();
+
+    if (auto *BI = dyn_cast<BranchInst>(T)) {
+      output.insert(BI->getSuccessor(0));
+      if (BI->isConditional())
+        output.insert(BI->getSuccessor(1));
+      return output;
+    }
+
+    if (auto *SI = dyn_cast<SwitchInst>(T)) {
+      output.insert(SI->getDefaultDest());
+      for (auto &Case : SI->cases())
+        output.insert(Case.getCaseSuccessor());
+      return output;
+    }
+
+    // Only the terminator kind matters for a return, so test it with isa<>
+    // instead of binding a dyn_cast<> result that is never read.
+    if (isa<ReturnInst>(T))
+      return output;
+
+    assert(false && "Unhandled terminator type.");
+    return output;
+  }
+
+  /// Create a value in BB set to the value associated with the branch the block
+  /// terminator will take.
+  ///
+  /// |TargetToValue| maps each exit target to its constant. For a branch
+  /// terminator:
+  ///  - if both successors are in the map, a select on the branch condition is
+  ///    inserted before the terminator and returned;
+  ///  - if exactly one successor is in the map, its constant is returned;
+  ///  - if none is, nullptr is returned.
+  /// A return terminator yields nullptr. Any other terminator asserts, and
+  /// yields nullptr when assertions are compiled out.
+  llvm::Value *createExitVariable(
+      BasicBlock *BB,
+      const std::unordered_map<BasicBlock *, ConstantInt *> &TargetToValue) {
+    auto *T = BB->getTerminator();
+    // Only the terminator kind matters here, so no cast result is bound.
+    if (isa<ReturnInst>(T))
+      return nullptr;
+
+    IRBuilder<> Builder(BB);
+    Builder.SetInsertPoint(T);
+
+    if (auto *BI = dyn_cast<BranchInst>(T)) {
+      BasicBlock *LHSTarget = BI->getSuccessor(0);
+      // An unconditional branch has no second successor; the null pointer then
+      // simply misses in the lookup below (the map is not expected to hold a
+      // null key).
+      BasicBlock *RHSTarget =
+          BI->isConditional() ? BI->getSuccessor(1) : nullptr;
+
+      Value *LHS = TargetToValue.count(LHSTarget) != 0
+                       ? TargetToValue.at(LHSTarget)
+                       : nullptr;
+      Value *RHS = TargetToValue.count(RHSTarget) != 0
+                       ? TargetToValue.at(RHSTarget)
+                       : nullptr;
+
+      if (LHS == nullptr || RHS == nullptr)
+        return LHS == nullptr ? RHS : LHS;
+      return Builder.CreateSelect(BI->getCondition(), LHS, RHS);
+    }
+
+    // TODO: add support for switch cases.
+    assert(false && "Unhandled terminator type.");
+    // Without this return, control would flow off the end of a non-void
+    // function when assertions are disabled, which is undefined behavior.
+    return nullptr;
+  }
+
+  /// Replaces |BB|'s branch targets present in |ToReplace| with |NewTarget|.
+  ///
+  /// Branch and switch terminators have each matching successor slot
+  /// rewritten; a return terminator is left untouched. Asserts on any other
+  /// terminator kind. |ToReplace| is taken by reference so the set is not
+  /// copied on every call.
+  void replaceBranchTargets(BasicBlock *BB,
+                            const std::unordered_set<BasicBlock *> &ToReplace,
+                            BasicBlock *NewTarget) {
+    auto *T = BB->getTerminator();
+    // Only the terminator kind matters here, so no cast result is bound.
+    if (isa<ReturnInst>(T))
+      return;
+
+    if (auto *BI = dyn_cast<BranchInst>(T)) {
+      for (size_t i = 0; i < BI->getNumSuccessors(); i++) {
+        if (ToReplace.count(BI->getSuccessor(i)) != 0)
+          BI->setSuccessor(i, NewTarget);
+      }
+      return;
+    }
+
+    if (auto *SI = dyn_cast<SwitchInst>(T)) {
+      for (size_t i = 0; i < SI->getNumSuccessors(); i++) {
+        if (ToReplace.count(SI->getSuccessor(i)) != 0)
+          SI->setSuccessor(i, NewTarget);
+      }
+      return;
+    }
+
+    assert(false && "Unhandled terminator type.");
+  }
+
+  /// Runs the pass on the given convergence region, ignoring the sub-regions.
+  ///
+  /// When the region's exits branch to more than one block outside of the
+  /// region, a "new.exit" block is created holding a phi node and a switch:
+  /// each exit feeds the phi with the constant assigned to the target it
+  /// branches to, and the switch forwards to that target. The exits are then
+  /// redirected to the new block.
+  ///
+  /// Returns true if the CFG changed, false otherwise.
+  bool runOnConvergenceRegionNoRecurse(LoopInfo &LI,
+                                       const SPIRV::ConvergenceRegion *CR) {
+    // Gather all the exit targets for this region.
+    std::unordered_set<BasicBlock *> ExitTargets;
+    for (BasicBlock *Exit : CR->Exits) {
+      for (BasicBlock *Target : gatherSuccessors(Exit)) {
+        if (CR->Blocks.count(Target) == 0)
+          ExitTargets.insert(Target);
+      }
+    }
+
+    // If we have zero or one exit target, nothing to do.
+    if (ExitTargets.size() <= 1)
+      return false;
+
+    // Create the new single exit target.
+    auto F = CR->Entry->getParent();
+    auto NewExitTarget = BasicBlock::Create(F->getContext(), "new.exit", F);
+    IRBuilder<> Builder(NewExitTarget);
+
+    // CodeGen output needs to be stable. Using the set as-is would order
+    // the targets differently depending on the allocation pattern.
+    // Sorting per basic-block ordering in the function.
+    std::vector<BasicBlock *> SortedExitTargets;
+    std::vector<BasicBlock *> SortedExits;
+    for (BasicBlock &BB : *F) {
+      if (ExitTargets.count(&BB) != 0)
+        SortedExitTargets.push_back(&BB);
+      if (CR->Exits.count(&BB) != 0)
+        SortedExits.push_back(&BB);
+    }
+
+    // Creating one constant per distinct exit target. This will be routed to
+    // the correct target.
+    std::unordered_map<BasicBlock *, ConstantInt *> TargetToValue;
+    for (BasicBlock *Target : SortedExitTargets)
+      TargetToValue.emplace(Target, Builder.getInt32(TargetToValue.size()));
+
+    // Creating one variable per exit node, set to the constant matching the
+    // targeted external block.
+    std::vector<std::pair<BasicBlock *, Value *>> ExitToVariable;
+    for (auto Exit : SortedExits) {
+      llvm::Value *Value = createExitVariable(Exit, TargetToValue);
+      ExitToVariable.emplace_back(std::make_pair(Exit, Value));
+    }
+
+    // Gather the correct value depending on the exit we came from.
+    llvm::PHINode *node =
+        Builder.CreatePHI(Builder.getInt32Ty(), ExitToVariable.size());
+    for (auto [BB, Value] : ExitToVariable) {
+      node->addIncoming(Value, BB);
+    }
+
+    // Creating the switch to jump to the correct exit target. The cases are
+    // emitted by walking SortedExitTargets rather than TargetToValue: the map
+    // is unordered and keyed by pointer, so iterating it would make the default
+    // destination and the case order depend on the allocation pattern.
+    llvm::SwitchInst *Sw = Builder.CreateSwitch(node, SortedExitTargets[0],
+                                                SortedExitTargets.size() - 1);
+    for (size_t i = 1; i < SortedExitTargets.size(); i++) {
+      BasicBlock *Target = SortedExitTargets[i];
+      Sw->addCase(TargetToValue.at(Target), Target);
+    }
+
+    // Fix exit branches to redirect to the new exit.
+    for (auto Exit : CR->Exits)
+      replaceBranchTargets(Exit, ExitTargets, NewExitTarget);
+
+    return true;
+  }
+
+  /// Applies the pass to |CR|, visiting its sub-regions first (depth-first).
+  /// Stops at the first region or sub-region that gets modified and reports
+  /// true; reports false when nothing was modified.
+  bool runOnConvergenceRegion(LoopInfo &LI,
+                              const SPIRV::ConvergenceRegion *CR) {
+    for (auto *SubRegion : CR->Children) {
+      const bool SubRegionChanged = runOnConvergenceRegion(LI, SubRegion);
+      if (SubRegionChanged)
+        return true;
+    }
+
+    // No sub-region changed: handle this region itself.
+    const bool RegionChanged = runOnConvergenceRegionNoRecurse(LI, CR);
+    return RegionChanged;
+  }
+
+#if !NDEBUG
+  /// Checks, for |CR| and all of its sub-regions, that the edges leaving the
+  /// region all lead to the same basic block. Sub-regions are checked first.
+  /// Asserts when a region has more than one distinct outside target.
+  void validateRegionExits(const SPIRV::ConvergenceRegion *CR) {
+    for (auto *SubRegion : CR->Children)
+      validateRegionExits(SubRegion);
+
+    // Collect every successor of an exit block that lies outside the region.
+    std::unordered_set<BasicBlock *> OutsideTargets;
+    for (auto *ExitBlock : CR->Exits) {
+      const auto Successors = gatherSuccessors(ExitBlock);
+      for (auto *Succ : Successors) {
+        const bool InsideRegion = CR->Blocks.count(Succ) != 0;
+        if (!InsideRegion)
+          OutsideTargets.insert(Succ);
+      }
+    }
+
+    assert(OutsideTargets.size() <= 1);
+  }
+#endif
+
+  virtual bool runOnFunction(Function &F) override {
+    LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+    const auto *TopLevelRegion =
+        getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
+            .getRegionInfo()
+            .getTopLevelRegion();
+
+    // FIXME: very inefficient method: each time a region is modified, we bubble
+    // back up, and recompute the whole convergence region tree. Once the
+    // algorithm is completed and test coverage good enough, rewrite this pass
+    // to be efficient instead of simple.
+    bool modified = false;
+    while (runOnConvergenceRegion(LI, TopLevelRegion)) {
+      TopLevelRegion = getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
+                           .getRegionInfo()
+                           .getTopLevelRegion();
+      modified = true;
+    }
+
+#if !NDEBUG
----------------
Keenuts wrote:

Seems like some passes enable it on both NDEBUG and EXPENSIVE_CHECKS, and some only EXPENSIVE_CHECKS. Since this is not that costly to do, I'll enable it in both cases (at least until this structurizer is more stable/robust).

https://github.com/llvm/llvm-project/pull/92531


More information about the llvm-commits mailing list