[llvm] [SPIR-V] Add pass to merge convergence region exit targets (PR #92531)

Nathan Gauër via llvm-commits llvm-commits at lists.llvm.org
Fri May 17 05:28:30 PDT 2024


https://github.com/Keenuts updated https://github.com/llvm/llvm-project/pull/92531

>From 5ae8f2dc2d4d4f51d24b2d19b51fcffee08d82b8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Wed, 15 May 2024 17:10:13 +0200
Subject: [PATCH] [SPIR-V] Add pass to merge convergence region exit targets
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The structurizer requires regions to be SESE: single entry, single
exit.
This new pass transforms multiple-exit regions into single-exit regions.

```
      +---+
      | A |
      +---+
      /   \
   +---+ +---+
   | B | | C |  A, B & C belong to the same convergence region.
   +---+ +---+
     |     |
   +---+ +---+
   | D | | E |  D & E belong to the parent convergence region.
   +---+ +---+  This means B & C are the exit blocks of the region.
      \   /     And D & E the targets of those exits.
       \ /
        |
      +---+
      | F |
      +---+
```

This pass would assign one value per exit target:
D = 0
E = 1

Then, create one variable per exit block (B, C), and assign it
to the correct value: in B, the variable will have the value 0,
and in C, the value 1.

Then, we'd create a new block H, with a PHI node to gather those
2 variables, and a switch, to route to the correct target.

Finally, the branches in B and C are updated to exit to this new block.

```
      +---+
      | A |
      +---+
      /   \
   +---+ +---+
   | B | | C |
   +---+ +---+
      \   /
      +---+
      | H |
      +---+
      /   \
   +---+ +---+
   | D | | E |
   +---+ +---+
      \   /
       \ /
        |
      +---+
      | F |
      +---+
```

Note: the variable is set depending on the condition used to branch.
If B's terminator was conditional, the variable would be set using a
SELECT.
All internal edges of a region are left intact, only exiting edges are
updated.

Signed-off-by: Nathan Gauër <brioche at google.com>
---
 llvm/lib/Target/SPIRV/CMakeLists.txt          |   1 +
 llvm/lib/Target/SPIRV/SPIRV.h                 |   1 +
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp |  23 ++
 llvm/lib/Target/SPIRV/SPIRVInstrInfo.td       |   2 +-
 .../SPIRV/SPIRVMergeRegionExitTargets.cpp     | 290 ++++++++++++++++++
 llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp  |   1 +
 .../SPIRV/structurizer/merge-exit-break.ll    |  84 +++++
 .../merge-exit-convergence-in-break.ll        |  94 ++++++
 .../structurizer/merge-exit-multiple-break.ll | 103 +++++++
 .../merge-exit-simple-white-identity.ll       |  49 +++
 10 files changed, 647 insertions(+), 1 deletion(-)
 create mode 100644 llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp
 create mode 100644 llvm/test/CodeGen/SPIRV/structurizer/merge-exit-break.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/structurizer/merge-exit-convergence-in-break.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/structurizer/merge-exit-multiple-break.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/structurizer/merge-exit-simple-white-identity.ll

diff --git a/llvm/lib/Target/SPIRV/CMakeLists.txt b/llvm/lib/Target/SPIRV/CMakeLists.txt
index 7001ac382f41c..35a463a89ec64 100644
--- a/llvm/lib/Target/SPIRV/CMakeLists.txt
+++ b/llvm/lib/Target/SPIRV/CMakeLists.txt
@@ -24,6 +24,7 @@ add_llvm_target(SPIRVCodeGen
   SPIRVInstrInfo.cpp
   SPIRVInstructionSelector.cpp
   SPIRVStripConvergentIntrinsics.cpp
+  SPIRVMergeRegionExitTargets.cpp
   SPIRVISelLowering.cpp
   SPIRVLegalizerInfo.cpp
   SPIRVMCInstLower.cpp
diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h
index fb8580cd47c01..e597a1dc8dc06 100644
--- a/llvm/lib/Target/SPIRV/SPIRV.h
+++ b/llvm/lib/Target/SPIRV/SPIRV.h
@@ -20,6 +20,7 @@ class InstructionSelector;
 class RegisterBankInfo;
 
 ModulePass *createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM);
+FunctionPass *createSPIRVMergeRegionExitTargetsPass();
 FunctionPass *createSPIRVStripConvergenceIntrinsicsPass();
 FunctionPass *createSPIRVRegularizerPass();
 FunctionPass *createSPIRVPreLegalizerPass();
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 32df2403dfe52..057cdd7a3ee2c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -150,6 +150,16 @@ class SPIRVEmitIntrinsics
     ModulePass::getAnalysisUsage(AU);
   }
 };
+
+// Returns true if |I| is a convergence entry, loop or anchor intrinsic call.
+bool isConvergenceIntrinsic(const Instruction *I) {
+  const auto *Call = dyn_cast<IntrinsicInst>(I);
+  // Non-intrinsic instructions map to not_intrinsic, which matches nothing.
+  const auto IID = Call ? Call->getIntrinsicID() : Intrinsic::not_intrinsic;
+  return IID == Intrinsic::experimental_convergence_entry ||
+         IID == Intrinsic::experimental_convergence_loop ||
+         IID == Intrinsic::experimental_convergence_anchor;
+}
 } // namespace
 
 char SPIRVEmitIntrinsics::ID = 0;
@@ -1074,6 +1084,10 @@ void SPIRVEmitIntrinsics::processGlobalValue(GlobalVariable &GV,
 
 void SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I,
                                                    IRBuilder<> &B) {
+  // Don't assign types to LLVM tokens.
+  if (isConvergenceIntrinsic(I))
+    return;
+
   reportFatalOnTokenType(I);
   if (!isPointerTy(I->getType()) || !requireAssignType(I) ||
       isa<BitCastInst>(I))
@@ -1092,6 +1106,10 @@ void SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I,
 
 void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
                                                 IRBuilder<> &B) {
+  // Don't assign types to LLVM tokens.
+  if (isConvergenceIntrinsic(I))
+    return;
+
   reportFatalOnTokenType(I);
   Type *Ty = I->getType();
   if (!Ty->isVoidTy() && !isPointerTy(Ty) && requireAssignType(I)) {
@@ -1319,6 +1337,11 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
     I = visit(*I);
     if (!I)
       continue;
+
+    // Don't emit intrinsics for convergence operations.
+    if (isConvergenceIntrinsic(I))
+      continue;
+
     processInstrAfterVisit(I, B);
   }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index 151d0ec1fe569..6634481daf12e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -615,7 +615,7 @@ def OpFwidthCoarse: UnOp<"OpFwidthCoarse", 215>;
 def OpPhi: Op<245, (outs ID:$res), (ins TYPE:$type, ID:$var0, ID:$block0, variable_ops),
                   "$res = OpPhi $type $var0 $block0">;
 def OpLoopMerge: Op<246, (outs), (ins ID:$merge, ID:$continue, LoopControl:$lc, variable_ops),
-                  "OpLoopMerge $merge $merge $continue $lc">;
+                  "OpLoopMerge $merge $continue $lc">;
 def OpSelectionMerge: Op<247, (outs), (ins ID:$merge, SelectionControl:$sc),
                   "OpSelectionMerge $merge $sc">;
 def OpLabel: Op<248, (outs ID:$label), (ins), "$label = OpLabel">;
diff --git a/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp b/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp
new file mode 100644
index 0000000000000..13781e24f0d42
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp
@@ -0,0 +1,290 @@
+//===-- SPIRVMergeRegionExitTargets.cpp ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Merge the multiple exit targets of a convergence region into a single block.
+// Each exit target will be assigned a constant value, and a phi node + switch
+// will allow the new exit target to re-route to the correct basic block.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Analysis/SPIRVConvergenceRegionAnalysis.h"
+#include "SPIRV.h"
+#include "SPIRVSubtarget.h"
+#include "SPIRVTargetMachine.h"
+#include "SPIRVUtils.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/CodeGen/IntrinsicLowering.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Utils/LoopSimplify.h"
+#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
+
+using namespace llvm;
+
+namespace llvm {
+void initializeSPIRVMergeRegionExitTargetsPass(PassRegistry &);
+} // namespace llvm
+
+namespace llvm {
+
+class SPIRVMergeRegionExitTargets : public FunctionPass {
+public:
+  static char ID;
+
+  SPIRVMergeRegionExitTargets() : FunctionPass(ID) {
+    initializeSPIRVMergeRegionExitTargetsPass(*PassRegistry::getPassRegistry());
+  } // Initializes the pass in the PassRegistry on construction.
+
+  /// Returns the set of distinct successor blocks of |BB|.
+  /// Asserts if the terminator is neither a branch, a switch nor a return.
+  std::unordered_set<BasicBlock *> gatherSuccessors(BasicBlock *BB) {
+    std::unordered_set<BasicBlock *> Successors;
+    auto *T = BB->getTerminator();
+
+    // A return has no successor: nothing to gather.
+    if (isa<ReturnInst>(T))
+      return Successors;
+
+    if (auto *BI = dyn_cast<BranchInst>(T)) {
+      Successors.insert(BI->getSuccessor(0));
+      if (BI->isConditional())
+        Successors.insert(BI->getSuccessor(1));
+      return Successors;
+    }
+
+    if (auto *SI = dyn_cast<SwitchInst>(T)) {
+      Successors.insert(SI->getDefaultDest());
+      for (auto &Case : SI->cases())
+        Successors.insert(Case.getCaseSuccessor());
+      return Successors;
+    }
+
+    assert(false && "Unhandled terminator type.");
+    return Successors;
+  }
+
+  /// Create a value in BB set to the value associated with the branch the block
+  /// terminator will take. Returns nullptr if |BB| ends with a return, or if
+  /// none of its branch targets has an entry in |TargetToValue|.
+  llvm::Value *createExitVariable(
+      BasicBlock *BB,
+      const std::unordered_map<BasicBlock *, ConstantInt *> &TargetToValue) {
+    auto *T = BB->getTerminator();
+    if (isa<ReturnInst>(T))
+      return nullptr;
+
+    // Returns the routing constant assigned to |Target|, or nullptr when
+    // |Target| is not one of the exit targets.
+    auto LookupValue = [&TargetToValue](BasicBlock *Target) -> Value * {
+      auto It = TargetToValue.find(Target);
+      return It == TargetToValue.end() ? nullptr : It->second;
+    };
+
+    if (auto *BI = dyn_cast<BranchInst>(T)) {
+      Value *LHS = LookupValue(BI->getSuccessor(0));
+      Value *RHS =
+          BI->isConditional() ? LookupValue(BI->getSuccessor(1)) : nullptr;
+
+      // At most one edge leaves the region: its constant (if any) is enough.
+      if (LHS == nullptr || RHS == nullptr)
+        return LHS == nullptr ? RHS : LHS;
+      // Both edges leave the region: select the constant from the condition.
+      IRBuilder<> Builder(T);
+      return Builder.CreateSelect(BI->getCondition(), LHS, RHS);
+    }
+
+    // TODO: add support for switch cases.
+    assert(false && "Unhandled terminator type.");
+    return nullptr;
+  }
+
+  /// Replaces |BB|'s branch targets present in |ToReplace| with |NewTarget|.
+  /// Blocks ending with a return are left untouched; any terminator other than
+  /// a branch, switch or return triggers an assert.
+  void replaceBranchTargets(BasicBlock *BB,
+                            const std::unordered_set<BasicBlock *> &ToReplace,
+                            BasicBlock *NewTarget) {
+    auto *T = BB->getTerminator();
+    if (isa<ReturnInst>(T))
+      return;
+
+    // Redirects every successor of |Term| found in |ToReplace|.
+    auto Redirect = [&ToReplace, NewTarget](auto *Term) {
+      for (unsigned Idx = 0, E = Term->getNumSuccessors(); Idx != E; ++Idx) {
+        if (ToReplace.count(Term->getSuccessor(Idx)) != 0)
+          Term->setSuccessor(Idx, NewTarget);
+      }
+    };
+
+    if (auto *BI = dyn_cast<BranchInst>(T))
+      return Redirect(BI);
+
+    if (auto *SI = dyn_cast<SwitchInst>(T))
+      return Redirect(SI);
+
+    assert(false && "Unhandled terminator type.");
+  }
+
+  // Run the pass on the given convergence region, ignoring the sub-regions.
+  // Returns true if the CFG changed, false otherwise.
+  bool runOnConvergenceRegionNoRecurse(LoopInfo &LI,
+                                       const SPIRV::ConvergenceRegion *CR) {
+    // Gather all the exit targets for this region.
+    std::unordered_set<BasicBlock *> ExitTargets;
+    for (BasicBlock *Exit : CR->Exits) {
+      for (BasicBlock *Target : gatherSuccessors(Exit)) {
+        if (CR->Blocks.count(Target) == 0)
+          ExitTargets.insert(Target);
+      }
+    }
+
+    // If we have zero or one exit target, nothing to do.
+    if (ExitTargets.size() <= 1)
+      return false;
+
+    // Create the new single exit target.
+    auto F = CR->Entry->getParent();
+    auto NewExitTarget = BasicBlock::Create(F->getContext(), "new.exit", F);
+    IRBuilder<> Builder(NewExitTarget);
+
+    // CodeGen output needs to be stable. Using the set as-is would order
+    // the targets differently depending on the allocation pattern.
+    // Sorting per basic-block ordering in the function.
+    std::vector<BasicBlock *> SortedExitTargets;
+    std::vector<BasicBlock *> SortedExits;
+    for (BasicBlock &BB : *F) {
+      if (ExitTargets.count(&BB) != 0)
+        SortedExitTargets.push_back(&BB);
+      if (CR->Exits.count(&BB) != 0)
+        SortedExits.push_back(&BB);
+    }
+
+    // Creating one constant per distinct exit target. This will be used to
+    // route to the correct target.
+    std::unordered_map<BasicBlock *, ConstantInt *> TargetToValue;
+    for (BasicBlock *Target : SortedExitTargets)
+      TargetToValue.emplace(Target, Builder.getInt32(TargetToValue.size()));
+
+    // Creating one variable per exit node, set to the constant matching the
+    // targeted external block. Exits without such a value (e.g. ending with a
+    // return) never branch to the new block, hence get no PHI entry.
+    std::vector<std::pair<BasicBlock *, Value *>> ExitToVariable;
+    for (BasicBlock *Exit : SortedExits) {
+      if (llvm::Value *Variable = createExitVariable(Exit, TargetToValue))
+        ExitToVariable.emplace_back(Exit, Variable);
+    }
+
+    // Gather the correct value depending on the exit we came from.
+    llvm::PHINode *Route =
+        Builder.CreatePHI(Builder.getInt32Ty(), ExitToVariable.size());
+    for (const auto &[Exit, Variable] : ExitToVariable)
+      Route->addIncoming(Variable, Exit);
+
+    // Creating the switch to jump to the correct exit target. Cases are taken
+    // from the sorted list, last target first (it becomes the default), as
+    // iterating over the unordered map would not give a stable order.
+    llvm::SwitchInst *Sw = Builder.CreateSwitch(
+        Route, SortedExitTargets.back(), SortedExitTargets.size() - 1);
+    for (size_t i = SortedExitTargets.size() - 1; i-- > 0;)
+      Sw->addCase(TargetToValue.at(SortedExitTargets[i]), SortedExitTargets[i]);
+
+    // Fix exit branches to redirect to the new exit.
+    for (auto Exit : CR->Exits)
+      replaceBranchTargets(Exit, ExitTargets, NewExitTarget);
+
+    return true;
+  }
+
+  /// Run the pass on the sub-regions of |CR| first (DFS), then on |CR| itself.
+  /// Returns true as soon as one region/sub-region has been modified.
+  bool runOnConvergenceRegion(LoopInfo &LI,
+                              const SPIRV::ConvergenceRegion *CR) {
+    for (auto *SubRegion : CR->Children) {
+      const bool SubRegionChanged = runOnConvergenceRegion(LI, SubRegion);
+      if (SubRegionChanged)
+        return true;
+    }
+    return runOnConvergenceRegionNoRecurse(LI, CR);
+  }
+
+#ifndef NDEBUG
+  /// Validates each edge exiting the region has the same destination basic
+  /// block. Sub-regions are validated before the region itself.
+  void validateRegionExits(const SPIRV::ConvergenceRegion *CR) {
+    for (auto *Child : CR->Children)
+      validateRegionExits(Child);
+
+    std::unordered_set<BasicBlock *> ExitTargets;
+    for (auto *Exit : CR->Exits) {
+      for (auto *BB : gatherSuccessors(Exit)) {
+        if (CR->Blocks.count(BB) == 0)
+          ExitTargets.insert(BB);
+      }
+    }
+
+    assert(ExitTargets.size() <= 1 &&
+           "A convergence region must have at most one exit target.");
+  }
+#endif
+
+  /// Merges region exit targets in |F| until no region changes anymore.
+  bool runOnFunction(Function &F) override {
+    LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+    const auto *TopLevelRegion =
+        getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
+            .getRegionInfo()
+            .getTopLevelRegion();
+
+    // FIXME: very inefficient method: each time a region is modified, we bubble
+    // back up, and recompute the whole convergence region tree. Once the
+    // algorithm is completed and test coverage good enough, rewrite this pass
+    // to be efficient instead of simple.
+    bool modified = false;
+    while (runOnConvergenceRegion(LI, TopLevelRegion)) {
+      TopLevelRegion = getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
+                           .getRegionInfo()
+                           .getTopLevelRegion();
+      modified = true;
+    }
+
+#ifndef NDEBUG
+    validateRegionExits(TopLevelRegion);
+#endif
+    return modified;
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<DominatorTreeWrapperPass>(); // Not queried in this pass.
+    AU.addRequired<LoopInfoWrapperPass>(); // Fetched in runOnFunction().
+    AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>(); // Same.
+    FunctionPass::getAnalysisUsage(AU);
+  }
+};
+} // namespace llvm
+
+char SPIRVMergeRegionExitTargets::ID = 0;
+
+INITIALIZE_PASS_BEGIN(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
+                      "SPIRV split region exit blocks", false, false)
+INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(SPIRVConvergenceRegionAnalysisWrapperPass)
+// NOTE(review): name/description say "split" but this pass merges; confirm.
+INITIALIZE_PASS_END(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
+                    "SPIRV split region exit blocks", false, false)
+
+FunctionPass *llvm::createSPIRVMergeRegionExitTargetsPass() {
+  return new SPIRVMergeRegionExitTargets(); // Raw pointer handed to caller.
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
index ae8baa3f11913..d0e51caf46e73 100644
--- a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
@@ -164,6 +164,7 @@ void SPIRVPassConfig::addIRPasses() {
     //  - all loop exits are dominated by the loop pre-header.
     //  - loops have a single back-edge.
     addPass(createLoopSimplifyPass());
+    addPass(createSPIRVMergeRegionExitTargetsPass());
   }
 
   TargetPassConfig::addIRPasses();
diff --git a/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-break.ll b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-break.ll
new file mode 100644
index 0000000000000..b3fcdc978625f
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-break.ll
@@ -0,0 +1,84 @@
+; RUN: llc -mtriple=spirv-unknown-vulkan-compute -O0 %s -o - | FileCheck %s --match-full-lines
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-G1"
+target triple = "spirv-unknown-vulkan-compute"
+
+define internal spir_func void @main() #0 {
+
+; CHECK:                      OpDecorate %[[#builtin:]] BuiltIn SubgroupLocalInvocationId
+; CHECK-DAG:  %[[#int_ty:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#pint_ty:]] = OpTypePointer Function %[[#int_ty]]
+; CHECK-DAG: %[[#bool_ty:]] = OpTypeBool
+; CHECK-DAG:   %[[#int_0:]] = OpConstant %[[#int_ty]] 0
+; CHECK-DAG:   %[[#int_1:]] = OpConstant %[[#int_ty]] 1
+; CHECK-DAG:  %[[#int_10:]] = OpConstant %[[#int_ty]] 10
+
+; CHECK:   %[[#entry:]] = OpLabel
+; CHECK:     %[[#idx:]] = OpVariable %[[#pint_ty]] Function
+; CHECK:                  OpStore %[[#idx]] %[[#int_0]] Aligned 4
+; CHECK:                  OpBranch %[[#while_cond:]]
+entry:
+  %0 = call token @llvm.experimental.convergence.entry()
+  %idx = alloca i32, align 4
+  store i32 0, ptr %idx, align 4
+  br label %while.cond
+
+; CHECK:   %[[#while_cond]] = OpLabel
+; CHECK:         %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#idx]] Aligned 4
+; CHECK:         %[[#cmp:]] = OpINotEqual %[[#bool_ty]] %[[#tmp]] %[[#int_10]]
+; CHECK:                      OpBranchConditional %[[#cmp]] %[[#while_body:]] %[[#new_end:]]
+while.cond:
+  %1 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %0) ]
+  %2 = load i32, ptr %idx, align 4
+  %cmp = icmp ne i32 %2, 10
+  br i1 %cmp, label %while.body, label %while.end
+
+; CHECK:   %[[#while_body]] = OpLabel
+; CHECK-NEXT:    %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#builtin]] Aligned 1
+; CHECK-NEXT:                 OpStore %[[#idx]] %[[#tmp]] Aligned 4
+; CHECK-NEXT:    %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#idx]] Aligned 4
+; CHECK-NEXT:   %[[#cmp1:]] = OpIEqual %[[#bool_ty]] %[[#tmp]] %[[#int_0]]
+; CHECK:                      OpBranchConditional %[[#cmp1]] %[[#new_end]] %[[#if_end:]]
+while.body:
+  %3 = call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %1) ]
+  store i32 %3, ptr %idx, align 4
+  %4 = load i32, ptr %idx, align 4
+  %cmp1 = icmp eq i32 %4, 0
+  br i1 %cmp1, label %if.then, label %if.end
+
+; CHECK:   %[[#if_then:]] = OpLabel
+; CHECK:                    OpBranch %[[#while_end:]]
+if.then:
+  br label %while.end
+
+; CHECK:   %[[#if_end]] = OpLabel
+; CHECK:                  OpBranch %[[#while_cond]]
+if.end:
+  br label %while.cond
+
+; CHECK:   %[[#while_end_loopexit:]] = OpLabel
+; CHECK:                               OpBranch %[[#while_end]]
+
+; CHECK:   %[[#while_end]] = OpLabel
+; CHECK:                     OpReturn
+while.end:
+  ret void
+
+; CHECK:   %[[#new_end]] = OpLabel
+; CHECK:    %[[#route:]] = OpPhi %[[#int_ty]] %[[#int_1]] %[[#while_cond]] %[[#int_0]] %[[#while_body]]
+; CHECK:                   OpSwitch %[[#route]] %[[#while_end_loopexit]] 0 %[[#if_then]]
+}
+
+declare token @llvm.experimental.convergence.entry() #2
+declare token @llvm.experimental.convergence.loop() #2
+declare i32 @__hlsl_wave_get_lane_index() #3
+
+attributes #0 = { convergent noinline norecurse nounwind optnone "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { convergent norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #2 = { convergent nocallback nofree nosync nounwind willreturn memory(none) }
+attributes #3 = { convergent }
+
+!llvm.module.flags = !{!0, !1}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 4, !"dx.disable_optimizations", i32 1}
diff --git a/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-convergence-in-break.ll b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-convergence-in-break.ll
new file mode 100644
index 0000000000000..a67c58fdd5749
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-convergence-in-break.ll
@@ -0,0 +1,94 @@
+; RUN: llc -mtriple=spirv-unknown-vulkan-compute -O0 %s -o - | FileCheck %s --match-full-lines
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-G1"
+target triple = "spirv-unknown-vulkan-compute"
+
+define internal spir_func void @main() #0 {
+
+; CHECK:                      OpDecorate %[[#builtin:]] BuiltIn SubgroupLocalInvocationId
+; CHECK-DAG:  %[[#int_ty:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#pint_ty:]] = OpTypePointer Function %[[#int_ty]]
+; CHECK-DAG: %[[#bool_ty:]] = OpTypeBool
+; CHECK-DAG:   %[[#int_0:]] = OpConstant %[[#int_ty]] 0
+; CHECK-DAG:   %[[#int_1:]] = OpConstant %[[#int_ty]] 1
+; CHECK-DAG:  %[[#int_10:]] = OpConstant %[[#int_ty]] 10
+
+; CHECK:   %[[#entry:]] = OpLabel
+; CHECK:     %[[#idx:]] = OpVariable %[[#pint_ty]] Function
+; CHECK:                  OpStore %[[#idx]] %[[#int_0]] Aligned 4
+; CHECK:                  OpBranch %[[#while_cond:]]
+entry:
+  %0 = call token @llvm.experimental.convergence.entry()
+  %idx = alloca i32, align 4
+  store i32 0, ptr %idx, align 4
+  br label %while.cond
+
+; CHECK:   %[[#while_cond]] = OpLabel
+; CHECK:         %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#idx]] Aligned 4
+; CHECK:         %[[#cmp:]] = OpINotEqual %[[#bool_ty]] %[[#tmp]] %[[#int_10]]
+; CHECK:                      OpBranchConditional %[[#cmp]] %[[#while_body:]] %[[#new_end:]]
+while.cond:
+  %1 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %0) ]
+  %2 = load i32, ptr %idx, align 4
+  %cmp = icmp ne i32 %2, 10
+  br i1 %cmp, label %while.body, label %while.end
+
+; CHECK:   %[[#while_body]] = OpLabel
+; CHECK-NEXT:    %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#builtin]] Aligned 1
+; CHECK-NEXT:                 OpStore %[[#idx]] %[[#tmp]] Aligned 4
+; CHECK-NEXT:    %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#idx]] Aligned 4
+; CHECK-NEXT:   %[[#cmp1:]] = OpIEqual %[[#bool_ty]] %[[#tmp]] %[[#int_0]]
+; CHECK:                      OpBranchConditional %[[#cmp1]] %[[#if_then:]] %[[#if_end:]]
+while.body:
+  %3 = call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %1) ]
+  store i32 %3, ptr %idx, align 4
+  %4 = load i32, ptr %idx, align 4
+  %cmp1 = icmp eq i32 %4, 0
+  br i1 %cmp1, label %if.then, label %if.end
+
+; CHECK:        %[[#if_then:]] = OpLabel
+; CHECK-NEXT:                    OpBranch %[[#tail:]]
+if.then:
+  br label %tail
+
+; CHECK:        %[[#tail:]] = OpLabel
+; CHECK-NEXT:       %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#builtin]] Aligned 1
+; CHECK-NEXT:                    OpStore %[[#idx]] %[[#tmp]] Aligned 4
+; CHECK:                         OpBranch %[[#new_end:]]
+tail:
+  %5 = call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %1) ]
+  store i32 %5, ptr %idx, align 4
+  br label %while.end
+
+; CHECK:   %[[#if_end]] = OpLabel
+; CHECK:                  OpBranch %[[#while_cond]]
+if.end:
+  br label %while.cond
+
+; CHECK:   %[[#while_end_loopexit:]] = OpLabel
+; CHECK:                               OpBranch %[[#while_end:]]
+
+; CHECK:   %[[#while_end]] = OpLabel
+; CHECK:                     OpReturn
+while.end:
+  ret void
+
+; CHECK:   %[[#new_end]] = OpLabel
+; CHECK:    %[[#route:]] = OpPhi %[[#int_ty]] %[[#int_0]] %[[#while_cond]] %[[#int_1]] %[[#tail]]
+; CHECK:                   OpSwitch %[[#route]] %[[#while_end]] 0 %[[#while_end_loopexit]]
+}
+
+declare token @llvm.experimental.convergence.entry() #2
+declare token @llvm.experimental.convergence.loop() #2
+declare i32 @__hlsl_wave_get_lane_index() #3
+
+attributes #0 = { convergent noinline norecurse nounwind optnone "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { convergent norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #2 = { convergent nocallback nofree nosync nounwind willreturn memory(none) }
+attributes #3 = { convergent }
+
+!llvm.module.flags = !{!0, !1}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 4, !"dx.disable_optimizations", i32 1}
+
diff --git a/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-multiple-break.ll b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-multiple-break.ll
new file mode 100644
index 0000000000000..32a97553df05e
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-multiple-break.ll
@@ -0,0 +1,103 @@
+; RUN: llc -mtriple=spirv-unknown-vulkan-compute -O0 %s -o - | FileCheck %s --match-full-lines
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-G1"
+target triple = "spirv-unknown-vulkan-compute"
+
+define internal spir_func void @main() #0 {
+
+; CHECK:                      OpDecorate %[[#builtin:]] BuiltIn SubgroupLocalInvocationId
+; CHECK-DAG:  %[[#int_ty:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#pint_ty:]] = OpTypePointer Function %[[#int_ty]]
+; CHECK-DAG: %[[#bool_ty:]] = OpTypeBool
+; CHECK-DAG:   %[[#int_0:]] = OpConstant %[[#int_ty]] 0
+; CHECK-DAG:   %[[#int_1:]] = OpConstant %[[#int_ty]] 1
+; CHECK-DAG:   %[[#int_2:]] = OpConstant %[[#int_ty]] 2
+; CHECK-DAG:  %[[#int_10:]] = OpConstant %[[#int_ty]] 10
+
+; CHECK:   %[[#entry:]] = OpLabel
+; CHECK:     %[[#idx:]] = OpVariable %[[#pint_ty]] Function
+; CHECK:                  OpStore %[[#idx]] %[[#int_0]] Aligned 4
+; CHECK:                  OpBranch %[[#while_cond:]]
+entry:
+  %0 = call token @llvm.experimental.convergence.entry()
+  %idx = alloca i32, align 4
+  store i32 0, ptr %idx, align 4
+  br label %while.cond
+
+; CHECK:   %[[#while_cond]] = OpLabel
+; CHECK:         %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#idx]] Aligned 4
+; CHECK:         %[[#cmp:]] = OpINotEqual %[[#bool_ty]] %[[#tmp]] %[[#int_10]]
+; CHECK:                      OpBranchConditional %[[#cmp]] %[[#while_body:]] %[[#new_end:]]
+while.cond:
+  %1 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %0) ]
+  %2 = load i32, ptr %idx, align 4
+  %cmp = icmp ne i32 %2, 10
+  br i1 %cmp, label %while.body, label %while.end
+
+; CHECK:   %[[#while_body]] = OpLabel
+; CHECK-NEXT:    %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#builtin]] Aligned 1
+; CHECK-NEXT:                 OpStore %[[#idx]] %[[#tmp]] Aligned 4
+; CHECK-NEXT:    %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#idx]] Aligned 4
+; CHECK-NEXT:   %[[#cmp1:]] = OpIEqual %[[#bool_ty]] %[[#tmp]] %[[#int_0]]
+; CHECK:                      OpBranchConditional %[[#cmp1]] %[[#new_end]] %[[#if_end:]]
+while.body:
+  %3 = call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %1) ]
+  store i32 %3, ptr %idx, align 4
+  %4 = load i32, ptr %idx, align 4
+  %cmp1 = icmp eq i32 %4, 0
+  br i1 %cmp1, label %if.then, label %if.end
+
+; CHECK:   %[[#if_then:]] = OpLabel
+; CHECK:                    OpBranch %[[#while_end:]]
+if.then:
+  br label %while.end
+
+; CHECK:       %[[#if_end]] = OpLabel
+; CHECK-NEXT:    %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#builtin]] Aligned 1
+; CHECK-NEXT:                 OpStore %[[#idx]] %[[#tmp]] Aligned 4
+; CHECK-NEXT:    %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#idx]] Aligned 4
+; CHECK-NEXT:   %[[#cmp2:]] = OpIEqual %[[#bool_ty]] %[[#tmp]] %[[#int_0]]
+; CHECK:                      OpBranchConditional %[[#cmp2]] %[[#new_end]] %[[#if_end2:]]
+if.end:
+  %5 = call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %1) ]
+  store i32 %5, ptr %idx, align 4
+  %6 = load i32, ptr %idx, align 4
+  %cmp2 = icmp eq i32 %6, 0
+  br i1 %cmp2, label %if.then2, label %if.end2
+
+; CHECK:   %[[#if_then2:]] = OpLabel
+; CHECK:                     OpBranch %[[#while_end:]]
+if.then2:
+  br label %while.end
+
+; CHECK:   %[[#if_end2]] = OpLabel
+; CHECK:                    OpBranch %[[#while_cond:]]
+if.end2:
+  br label %while.cond
+
+; CHECK:   %[[#while_end_loopexit:]] = OpLabel
+; CHECK:                               OpBranch %[[#while_end]]
+
+; CHECK:   %[[#while_end]] = OpLabel
+; CHECK:                     OpReturn
+while.end:
+  ret void
+
+; CHECK:   %[[#new_end]] = OpLabel
+; CHECK:    %[[#route:]] = OpPhi %[[#int_ty]] %[[#int_2]] %[[#while_cond]] %[[#int_0]] %[[#while_body]] %[[#int_1]] %[[#if_end]]
+; CHECK:                   OpSwitch %[[#route]] %[[#while_end_loopexit]] 1 %[[#if_then2]] 0 %[[#if_then]]
+}
+
+declare token @llvm.experimental.convergence.entry() #2
+declare token @llvm.experimental.convergence.loop() #2
+declare i32 @__hlsl_wave_get_lane_index() #3
+
+attributes #0 = { convergent noinline norecurse nounwind optnone "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { convergent norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #2 = { convergent nocallback nofree nosync nounwind willreturn memory(none) }
+attributes #3 = { convergent }
+
+!llvm.module.flags = !{!0, !1}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 4, !"dx.disable_optimizations", i32 1}
diff --git a/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-simple-white-identity.ll b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-simple-white-identity.ll
new file mode 100644
index 0000000000000..a8bf4fb0db989
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-simple-white-identity.ll
@@ -0,0 +1,49 @@
+; RUN: llc -mtriple=spirv-unknown-vulkan-compute -O0 %s -o - | FileCheck %s
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-G1"
+target triple = "spirv-unknown-vulkan-compute"
+
+define internal spir_func void @main() #0 {
+
+; CHECK:   %[[#entry:]] = OpLabel
+; CHECK:                  OpBranch %[[#while_cond:]]
+entry:
+  %0 = call token @llvm.experimental.convergence.entry()
+  %idx = alloca i32, align 4
+  store i32 -1, ptr %idx, align 4
+  br label %while.cond
+
+; CHECK:   %[[#while_cond]] = OpLabel
+; CHECK:                      OpBranchConditional %[[#cond:]] %[[#while_body:]] %[[#while_end:]]
+while.cond:
+  %1 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %0) ]
+  %2 = load i32, ptr %idx, align 4
+  %cmp = icmp ne i32 %2, 0
+  br i1 %cmp, label %while.body, label %while.end
+
+; CHECK:   %[[#while_body]] = OpLabel
+; CHECK:                      OpBranch %[[#while_cond]]
+while.body:
+  %3 = call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %1) ]
+  store i32 %3, ptr %idx, align 4
+  br label %while.cond
+
+     ; CHECK:   %[[#while_end]] = OpLabel
+; CHECK-NEXT:                     OpReturn
+while.end:
+  ret void
+}
+
+declare token @llvm.experimental.convergence.entry() #2
+declare token @llvm.experimental.convergence.loop() #2
+declare i32 @__hlsl_wave_get_lane_index() #3
+
+attributes #0 = { convergent noinline norecurse nounwind optnone "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { convergent norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #2 = { convergent nocallback nofree nosync nounwind willreturn memory(none) }
+attributes #3 = { convergent }
+
+!llvm.module.flags = !{!0, !1}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 4, !"dx.disable_optimizations", i32 1}



More information about the llvm-commits mailing list