[llvm] a5641f1 - [SPIR-V] Add pass to merge convergence region exit targets (#92531)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Jun 3 02:35:58 PDT 2024
Author: Nathan Gauër
Date: 2024-06-03T11:35:55+02:00
New Revision: a5641f106affc3afb899eee42eb40c2ded81f411
URL: https://github.com/llvm/llvm-project/commit/a5641f106affc3afb899eee42eb40c2ded81f411
DIFF: https://github.com/llvm/llvm-project/commit/a5641f106affc3afb899eee42eb40c2ded81f411.diff
LOG: [SPIR-V] Add pass to merge convergence region exit targets (#92531)
The structurizer requires regions to be SESE: single entry, single exit.
This new pass transforms multiple-exit regions into single-exit regions.
```
        +---+
        | A |
        +---+
       /     \
    +---+   +---+
    | B |   | C |   A, B & C belong to the same convergence region.
    +---+   +---+
      |       |
    +---+   +---+
    | D |   | E |   D & E belong to the parent convergence region.
    +---+   +---+   This means B & C are the exit blocks of the region,
       \     /      and D & E the targets of those exits.
        \   /
          |
        +---+
        | F |
        +---+
```
This pass assigns one value per exit target:
D = 0
E = 1
Then, it creates one variable per exit block (B, C), set to the value of
the exit target that block branches to: in B, the variable will have the
value 0, and in C, the value 1.
Then, a new block H is created, with a PHI node to gather those two
variables, and a switch to route to the correct target.
Finally, the branches in B and C are updated to exit to this new block.
```
        +---+
        | A |
        +---+
       /     \
    +---+   +---+
    | B |   | C |
    +---+   +---+
        \   /
        +---+
        | H |
        +---+
       /     \
    +---+   +---+
    | D |   | E |
    +---+   +---+
        \   /
         \ /
          |
        +---+
        | F |
        +---+
```
Note: the variable is set depending on the condition used to branch. If
B's terminator was conditional, the variable would be set using a
SELECT.
All internal edges of a region are left intact, only exiting edges are
updated.
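As an illustration, here is a minimal LLVM IR sketch of the result for the
example above (block and value names are hypothetical; only the
unconditional-exit case is shown, a conditional exit would compute its
routing value with a SELECT as noted):
```
; Exit blocks B and C now branch to the new single exit target.
b:
  br label %new.exit                        ; previously "br label %d"

c:
  br label %new.exit                        ; previously "br label %e"

; The new block H: a phi picks the routing value from the exit we came
; from, and a switch re-routes to the original exit target.
new.exit:
  %route = phi i32 [ 0, %b ], [ 1, %c ]     ; 0 was assigned to D, 1 to E
  switch i32 %route, label %d [
    i32 1, label %e
  ]
```
The first exit target becomes the switch's default destination, so only the
remaining targets need explicit cases.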
---------
Signed-off-by: Nathan Gauër <brioche at google.com>
Added:
llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp
llvm/test/CodeGen/SPIRV/structurizer/merge-exit-break.ll
llvm/test/CodeGen/SPIRV/structurizer/merge-exit-convergence-in-break.ll
llvm/test/CodeGen/SPIRV/structurizer/merge-exit-multiple-break.ll
llvm/test/CodeGen/SPIRV/structurizer/merge-exit-simple-white-identity.ll
Modified:
llvm/lib/Target/SPIRV/CMakeLists.txt
llvm/lib/Target/SPIRV/SPIRV.h
llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/SPIRV/CMakeLists.txt b/llvm/lib/Target/SPIRV/CMakeLists.txt
index fe09d5903045c..14647e92f5d08 100644
--- a/llvm/lib/Target/SPIRV/CMakeLists.txt
+++ b/llvm/lib/Target/SPIRV/CMakeLists.txt
@@ -25,6 +25,7 @@ add_llvm_target(SPIRVCodeGen
SPIRVInstrInfo.cpp
SPIRVInstructionSelector.cpp
SPIRVStripConvergentIntrinsics.cpp
+ SPIRVMergeRegionExitTargets.cpp
SPIRVISelLowering.cpp
SPIRVLegalizerInfo.cpp
SPIRVMCInstLower.cpp
diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h
index fb8580cd47c01..e597a1dc8dc06 100644
--- a/llvm/lib/Target/SPIRV/SPIRV.h
+++ b/llvm/lib/Target/SPIRV/SPIRV.h
@@ -20,6 +20,7 @@ class InstructionSelector;
class RegisterBankInfo;
ModulePass *createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM);
+FunctionPass *createSPIRVMergeRegionExitTargetsPass();
FunctionPass *createSPIRVStripConvergenceIntrinsicsPass();
FunctionPass *createSPIRVRegularizerPass();
FunctionPass *createSPIRVPreLegalizerPass();
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index ffbd1e17bad5e..5ef0be1cab722 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -151,6 +151,16 @@ class SPIRVEmitIntrinsics
ModulePass::getAnalysisUsage(AU);
}
};
+
+bool isConvergenceIntrinsic(const Instruction *I) {
+ const auto *II = dyn_cast<IntrinsicInst>(I);
+ if (!II)
+ return false;
+
+ return II->getIntrinsicID() == Intrinsic::experimental_convergence_entry ||
+ II->getIntrinsicID() == Intrinsic::experimental_convergence_loop ||
+ II->getIntrinsicID() == Intrinsic::experimental_convergence_anchor;
+}
} // namespace
char SPIRVEmitIntrinsics::ID = 0;
@@ -1353,6 +1363,10 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
Worklist.push_back(&I);
for (auto &I : Worklist) {
+ // Don't emit intrinsics for convergence intrinsics.
+ if (isConvergenceIntrinsic(I))
+ continue;
+
insertAssignPtrTypeIntrs(I, B);
insertAssignTypeIntrs(I, B);
insertPtrCastOrAssignTypeInstr(I, B);
@@ -1371,6 +1385,11 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
I = visit(*I);
if (!I)
continue;
+
+ // Don't emit intrinsics for convergence operations.
+ if (isConvergenceIntrinsic(I))
+ continue;
+
processInstrAfterVisit(I, B);
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index 7c9b84a48a2a7..dedfd5e6e32db 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -617,7 +617,7 @@ def OpFwidthCoarse: UnOp<"OpFwidthCoarse", 215>;
def OpPhi: Op<245, (outs ID:$res), (ins TYPE:$type, ID:$var0, ID:$block0, variable_ops),
"$res = OpPhi $type $var0 $block0">;
def OpLoopMerge: Op<246, (outs), (ins ID:$merge, ID:$continue, LoopControl:$lc, variable_ops),
- "OpLoopMerge $merge $merge $continue $lc">;
+ "OpLoopMerge $merge $continue $lc">;
def OpSelectionMerge: Op<247, (outs), (ins ID:$merge, SelectionControl:$sc),
"OpSelectionMerge $merge $sc">;
def OpLabel: Op<248, (outs ID:$label), (ins), "$label = OpLabel">;
diff --git a/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp b/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp
new file mode 100644
index 0000000000000..2cdeb32579038
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp
@@ -0,0 +1,284 @@
+//===-- SPIRVMergeRegionExitTargets.cpp ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Merge the multiple exit targets of a convergence region into a single block.
+// Each exit target will be assigned a constant value, and a phi node + switch
+// will allow the new exit target to re-route to the correct basic block.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Analysis/SPIRVConvergenceRegionAnalysis.h"
+#include "SPIRV.h"
+#include "SPIRVSubtarget.h"
+#include "SPIRVTargetMachine.h"
+#include "SPIRVUtils.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/CodeGen/IntrinsicLowering.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Utils/LoopSimplify.h"
+#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
+
+using namespace llvm;
+
+namespace llvm {
+void initializeSPIRVMergeRegionExitTargetsPass(PassRegistry &);
+
+class SPIRVMergeRegionExitTargets : public FunctionPass {
+public:
+ static char ID;
+
+ SPIRVMergeRegionExitTargets() : FunctionPass(ID) {
+ initializeSPIRVMergeRegionExitTargetsPass(*PassRegistry::getPassRegistry());
+ };
+
+ // Gather all the successors of |BB|.
+ // This function asserts if the terminator is neither a branch, a switch, nor a return.
+ std::unordered_set<BasicBlock *> gatherSuccessors(BasicBlock *BB) {
+ std::unordered_set<BasicBlock *> output;
+ auto *T = BB->getTerminator();
+
+ if (auto *BI = dyn_cast<BranchInst>(T)) {
+ output.insert(BI->getSuccessor(0));
+ if (BI->isConditional())
+ output.insert(BI->getSuccessor(1));
+ return output;
+ }
+
+ if (auto *SI = dyn_cast<SwitchInst>(T)) {
+ output.insert(SI->getDefaultDest());
+ for (auto &Case : SI->cases())
+ output.insert(Case.getCaseSuccessor());
+ return output;
+ }
+
+ if (auto *RI = dyn_cast<ReturnInst>(T))
+ return output;
+
+ assert(false && "Unhandled terminator type.");
+ return output;
+ }
+
+ /// Create a value in |BB| set to the value associated with the branch the
+ /// block's terminator will take.
+ llvm::Value *createExitVariable(
+ BasicBlock *BB,
+ const std::unordered_map<BasicBlock *, ConstantInt *> &TargetToValue) {
+ auto *T = BB->getTerminator();
+ if (auto *RI = dyn_cast<ReturnInst>(T))
+ return nullptr;
+
+ IRBuilder<> Builder(BB);
+ Builder.SetInsertPoint(T);
+
+ if (auto *BI = dyn_cast<BranchInst>(T)) {
+
+ BasicBlock *LHSTarget = BI->getSuccessor(0);
+ BasicBlock *RHSTarget =
+ BI->isConditional() ? BI->getSuccessor(1) : nullptr;
+
+ Value *LHS = TargetToValue.count(LHSTarget) != 0
+ ? TargetToValue.at(LHSTarget)
+ : nullptr;
+ Value *RHS = TargetToValue.count(RHSTarget) != 0
+ ? TargetToValue.at(RHSTarget)
+ : nullptr;
+
+ if (LHS == nullptr || RHS == nullptr)
+ return LHS == nullptr ? RHS : LHS;
+ return Builder.CreateSelect(BI->getCondition(), LHS, RHS);
+ }
+
+ // TODO: add support for switch cases.
+ assert(false && "Unhandled terminator type.");
+ }
+
+ /// Replaces |BB|'s branch targets present in |ToReplace| with |NewTarget|.
+ void replaceBranchTargets(BasicBlock *BB,
+ const std::unordered_set<BasicBlock *> ToReplace,
+ BasicBlock *NewTarget) {
+ auto *T = BB->getTerminator();
+ if (auto *RI = dyn_cast<ReturnInst>(T))
+ return;
+
+ if (auto *BI = dyn_cast<BranchInst>(T)) {
+ for (size_t i = 0; i < BI->getNumSuccessors(); i++) {
+ if (ToReplace.count(BI->getSuccessor(i)) != 0)
+ BI->setSuccessor(i, NewTarget);
+ }
+ return;
+ }
+
+ if (auto *SI = dyn_cast<SwitchInst>(T)) {
+ for (size_t i = 0; i < SI->getNumSuccessors(); i++) {
+ if (ToReplace.count(SI->getSuccessor(i)) != 0)
+ SI->setSuccessor(i, NewTarget);
+ }
+ return;
+ }
+
+ assert(false && "Unhandled terminator type.");
+ }
+
+ // Run the pass on the given convergence region, ignoring the sub-regions.
+ // Returns true if the CFG changed, false otherwise.
+ bool runOnConvergenceRegionNoRecurse(LoopInfo &LI,
+ const SPIRV::ConvergenceRegion *CR) {
+ // Gather all the exit targets for this region.
+ std::unordered_set<BasicBlock *> ExitTargets;
+ for (BasicBlock *Exit : CR->Exits) {
+ for (BasicBlock *Target : gatherSuccessors(Exit)) {
+ if (CR->Blocks.count(Target) == 0)
+ ExitTargets.insert(Target);
+ }
+ }
+
+ // If we have zero or one exit target, nothing to do.
+ if (ExitTargets.size() <= 1)
+ return false;
+
+ // Create the new single exit target.
+ auto F = CR->Entry->getParent();
+ auto NewExitTarget = BasicBlock::Create(F->getContext(), "new.exit", F);
+ IRBuilder<> Builder(NewExitTarget);
+
+ // CodeGen output needs to be stable. Using the set as-is would order the
+ // targets differently depending on the allocation pattern.
+ // Sorting per basic-block ordering in the function.
+ std::vector<BasicBlock *> SortedExitTargets;
+ std::vector<BasicBlock *> SortedExits;
+ for (BasicBlock &BB : *F) {
+ if (ExitTargets.count(&BB) != 0)
+ SortedExitTargets.push_back(&BB);
+ if (CR->Exits.count(&BB) != 0)
+ SortedExits.push_back(&BB);
+ }
+
+ // Creating one constant per distinct exit target. This will be used to
+ // route to the correct target.
+ std::unordered_map<BasicBlock *, ConstantInt *> TargetToValue;
+ for (BasicBlock *Target : SortedExitTargets)
+ TargetToValue.emplace(Target, Builder.getInt32(TargetToValue.size()));
+
+ // Creating one variable per exit node, set to the constant matching the
+ // targeted external block.
+ std::vector<std::pair<BasicBlock *, Value *>> ExitToVariable;
+ for (auto Exit : SortedExits) {
+ llvm::Value *Value = createExitVariable(Exit, TargetToValue);
+ ExitToVariable.emplace_back(std::make_pair(Exit, Value));
+ }
+
+ // Gather the correct value depending on the exit we came from.
+ llvm::PHINode *node =
+ Builder.CreatePHI(Builder.getInt32Ty(), ExitToVariable.size());
+ for (auto [BB, Value] : ExitToVariable) {
+ node->addIncoming(Value, BB);
+ }
+
+ // Creating the switch to jump to the correct exit target.
+ std::vector<std::pair<BasicBlock *, ConstantInt *>> CasesList(
+ TargetToValue.begin(), TargetToValue.end());
+ llvm::SwitchInst *Sw =
+ Builder.CreateSwitch(node, CasesList[0].first, CasesList.size() - 1);
+ for (size_t i = 1; i < CasesList.size(); i++)
+ Sw->addCase(CasesList[i].second, CasesList[i].first);
+
+ // Fix exit branches to redirect to the new exit.
+ for (auto Exit : CR->Exits)
+ replaceBranchTargets(Exit, ExitTargets, NewExitTarget);
+
+ return true;
+ }
+
+ /// Run the pass on the given convergence region and sub-regions (DFS).
+ /// Returns true if a region/sub-region was modified, false otherwise.
+ /// This returns as soon as one region/sub-region has been modified.
+ bool runOnConvergenceRegion(LoopInfo &LI,
+ const SPIRV::ConvergenceRegion *CR) {
+ for (auto *Child : CR->Children)
+ if (runOnConvergenceRegion(LI, Child))
+ return true;
+
+ return runOnConvergenceRegionNoRecurse(LI, CR);
+ }
+
+#if !defined(NDEBUG) || defined(EXPENSIVE_CHECKS)
+ /// Validates each edge exiting the region has the same destination basic
+ /// block.
+ void validateRegionExits(const SPIRV::ConvergenceRegion *CR) {
+ for (auto *Child : CR->Children)
+ validateRegionExits(Child);
+
+ std::unordered_set<BasicBlock *> ExitTargets;
+ for (auto *Exit : CR->Exits) {
+ auto Set = gatherSuccessors(Exit);
+ for (auto *BB : Set) {
+ if (CR->Blocks.count(BB) == 0)
+ ExitTargets.insert(BB);
+ }
+ }
+
+ assert(ExitTargets.size() <= 1);
+ }
+#endif
+
+ virtual bool runOnFunction(Function &F) override {
+ LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+ const auto *TopLevelRegion =
+ getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
+ .getRegionInfo()
+ .getTopLevelRegion();
+
+ // FIXME: very inefficient method: each time a region is modified, we bubble
+ // back up, and recompute the whole convergence region tree. Once the
+ // algorithm is completed and test coverage good enough, rewrite this pass
+ // to be efficient instead of simple.
+ bool modified = false;
+ while (runOnConvergenceRegion(LI, TopLevelRegion)) {
+ TopLevelRegion = getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
+ .getRegionInfo()
+ .getTopLevelRegion();
+ modified = true;
+ }
+
+#if !defined(NDEBUG) || defined(EXPENSIVE_CHECKS)
+ validateRegionExits(TopLevelRegion);
+#endif
+ return modified;
+ }
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.addRequired<DominatorTreeWrapperPass>();
+ AU.addRequired<LoopInfoWrapperPass>();
+ AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>();
+ FunctionPass::getAnalysisUsage(AU);
+ }
+};
+} // namespace llvm
+
+char SPIRVMergeRegionExitTargets::ID = 0;
+
+INITIALIZE_PASS_BEGIN(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
+ "SPIRV split region exit blocks", false, false)
+INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(SPIRVConvergenceRegionAnalysisWrapperPass)
+
+INITIALIZE_PASS_END(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
+ "SPIRV split region exit blocks", false, false)
+
+FunctionPass *llvm::createSPIRVMergeRegionExitTargetsPass() {
+ return new SPIRVMergeRegionExitTargets();
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
index ae8baa3f11913..a6823a8ba3230 100644
--- a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
@@ -164,6 +164,11 @@ void SPIRVPassConfig::addIRPasses() {
// - all loop exits are dominated by the loop pre-header.
// - loops have a single back-edge.
addPass(createLoopSimplifyPass());
+
+ // 2. Merge the convergence region exit nodes into one. After this step,
+ // regions are single-entry, single-exit. This will help determine the
+ // correct merge block.
+ addPass(createSPIRVMergeRegionExitTargetsPass());
}
TargetPassConfig::addIRPasses();
diff --git a/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-break.ll b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-break.ll
new file mode 100644
index 0000000000000..b3fcdc978625f
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-break.ll
@@ -0,0 +1,84 @@
+; RUN: llc -mtriple=spirv-unknown-vulkan-compute -O0 %s -o - | FileCheck %s --match-full-lines
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-G1"
+target triple = "spirv-unknown-vulkan-compute"
+
+define internal spir_func void @main() #0 {
+
+; CHECK: OpDecorate %[[#builtin:]] BuiltIn SubgroupLocalInvocationId
+; CHECK-DAG: %[[#int_ty:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#pint_ty:]] = OpTypePointer Function %[[#int_ty]]
+; CHECK-DAG: %[[#bool_ty:]] = OpTypeBool
+; CHECK-DAG: %[[#int_0:]] = OpConstant %[[#int_ty]] 0
+; CHECK-DAG: %[[#int_1:]] = OpConstant %[[#int_ty]] 1
+; CHECK-DAG: %[[#int_10:]] = OpConstant %[[#int_ty]] 10
+
+; CHECK: %[[#entry:]] = OpLabel
+; CHECK: %[[#idx:]] = OpVariable %[[#pint_ty]] Function
+; CHECK: OpStore %[[#idx]] %[[#int_0]] Aligned 4
+; CHECK: OpBranch %[[#while_cond:]]
+entry:
+ %0 = call token @llvm.experimental.convergence.entry()
+ %idx = alloca i32, align 4
+ store i32 0, ptr %idx, align 4
+ br label %while.cond
+
+; CHECK: %[[#while_cond]] = OpLabel
+; CHECK: %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#idx]] Aligned 4
+; CHECK: %[[#cmp:]] = OpINotEqual %[[#bool_ty]] %[[#tmp]] %[[#int_10]]
+; CHECK: OpBranchConditional %[[#cmp]] %[[#while_body:]] %[[#new_end:]]
+while.cond:
+ %1 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %0) ]
+ %2 = load i32, ptr %idx, align 4
+ %cmp = icmp ne i32 %2, 10
+ br i1 %cmp, label %while.body, label %while.end
+
+; CHECK: %[[#while_body]] = OpLabel
+; CHECK-NEXT: %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#builtin]] Aligned 1
+; CHECK-NEXT: OpStore %[[#idx]] %[[#tmp]] Aligned 4
+; CHECK-NEXT: %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#idx]] Aligned 4
+; CHECK-NEXT: %[[#cmp1:]] = OpIEqual %[[#bool_ty]] %[[#tmp]] %[[#int_0]]
+; CHECK: OpBranchConditional %[[#cmp1]] %[[#new_end]] %[[#if_end:]]
+while.body:
+ %3 = call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %1) ]
+ store i32 %3, ptr %idx, align 4
+ %4 = load i32, ptr %idx, align 4
+ %cmp1 = icmp eq i32 %4, 0
+ br i1 %cmp1, label %if.then, label %if.end
+
+; CHECK: %[[#if_then:]] = OpLabel
+; CHECK: OpBranch %[[#while_end:]]
+if.then:
+ br label %while.end
+
+; CHECK: %[[#if_end]] = OpLabel
+; CHECK: OpBranch %[[#while_cond]]
+if.end:
+ br label %while.cond
+
+; CHECK: %[[#while_end_loopexit:]] = OpLabel
+; CHECK: OpBranch %[[#while_end]]
+
+; CHECK: %[[#while_end]] = OpLabel
+; CHECK: OpReturn
+while.end:
+ ret void
+
+; CHECK: %[[#new_end]] = OpLabel
+; CHECK: %[[#route:]] = OpPhi %[[#int_ty]] %[[#int_1]] %[[#while_cond]] %[[#int_0]] %[[#while_body]]
+; CHECK: OpSwitch %[[#route]] %[[#while_end_loopexit]] 0 %[[#if_then]]
+}
+
+declare token @llvm.experimental.convergence.entry() #2
+declare token @llvm.experimental.convergence.loop() #2
+declare i32 @__hlsl_wave_get_lane_index() #3
+
+attributes #0 = { convergent noinline norecurse nounwind optnone "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { convergent norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #2 = { convergent nocallback nofree nosync nounwind willreturn memory(none) }
+attributes #3 = { convergent }
+
+!llvm.module.flags = !{!0, !1}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 4, !"dx.disable_optimizations", i32 1}
diff --git a/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-convergence-in-break.ll b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-convergence-in-break.ll
new file mode 100644
index 0000000000000..a67c58fdd5749
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-convergence-in-break.ll
@@ -0,0 +1,94 @@
+; RUN: llc -mtriple=spirv-unknown-vulkan-compute -O0 %s -o - | FileCheck %s --match-full-lines
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-G1"
+target triple = "spirv-unknown-vulkan-compute"
+
+define internal spir_func void @main() #0 {
+
+; CHECK: OpDecorate %[[#builtin:]] BuiltIn SubgroupLocalInvocationId
+; CHECK-DAG: %[[#int_ty:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#pint_ty:]] = OpTypePointer Function %[[#int_ty]]
+; CHECK-DAG: %[[#bool_ty:]] = OpTypeBool
+; CHECK-DAG: %[[#int_0:]] = OpConstant %[[#int_ty]] 0
+; CHECK-DAG: %[[#int_1:]] = OpConstant %[[#int_ty]] 1
+; CHECK-DAG: %[[#int_10:]] = OpConstant %[[#int_ty]] 10
+
+; CHECK: %[[#entry:]] = OpLabel
+; CHECK: %[[#idx:]] = OpVariable %[[#pint_ty]] Function
+; CHECK: OpStore %[[#idx]] %[[#int_0]] Aligned 4
+; CHECK: OpBranch %[[#while_cond:]]
+entry:
+ %0 = call token @llvm.experimental.convergence.entry()
+ %idx = alloca i32, align 4
+ store i32 0, ptr %idx, align 4
+ br label %while.cond
+
+; CHECK: %[[#while_cond]] = OpLabel
+; CHECK: %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#idx]] Aligned 4
+; CHECK: %[[#cmp:]] = OpINotEqual %[[#bool_ty]] %[[#tmp]] %[[#int_10]]
+; CHECK: OpBranchConditional %[[#cmp]] %[[#while_body:]] %[[#new_end:]]
+while.cond:
+ %1 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %0) ]
+ %2 = load i32, ptr %idx, align 4
+ %cmp = icmp ne i32 %2, 10
+ br i1 %cmp, label %while.body, label %while.end
+
+; CHECK: %[[#while_body]] = OpLabel
+; CHECK-NEXT: %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#builtin]] Aligned 1
+; CHECK-NEXT: OpStore %[[#idx]] %[[#tmp]] Aligned 4
+; CHECK-NEXT: %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#idx]] Aligned 4
+; CHECK-NEXT: %[[#cmp1:]] = OpIEqual %[[#bool_ty]] %[[#tmp]] %[[#int_0]]
+; CHECK: OpBranchConditional %[[#cmp1]] %[[#if_then:]] %[[#if_end:]]
+while.body:
+ %3 = call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %1) ]
+ store i32 %3, ptr %idx, align 4
+ %4 = load i32, ptr %idx, align 4
+ %cmp1 = icmp eq i32 %4, 0
+ br i1 %cmp1, label %if.then, label %if.end
+
+; CHECK: %[[#if_then:]] = OpLabel
+; CHECK-NEXT: OpBranch %[[#tail:]]
+if.then:
+ br label %tail
+
+; CHECK: %[[#tail:]] = OpLabel
+; CHECK-NEXT: %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#builtin]] Aligned 1
+; CHECK-NEXT: OpStore %[[#idx]] %[[#tmp]] Aligned 4
+; CHECK: OpBranch %[[#new_end:]]
+tail:
+ %5 = call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %1) ]
+ store i32 %5, ptr %idx, align 4
+ br label %while.end
+
+; CHECK: %[[#if_end]] = OpLabel
+; CHECK: OpBranch %[[#while_cond]]
+if.end:
+ br label %while.cond
+
+; CHECK: %[[#while_end_loopexit:]] = OpLabel
+; CHECK: OpBranch %[[#while_end:]]
+
+; CHECK: %[[#while_end]] = OpLabel
+; CHECK: OpReturn
+while.end:
+ ret void
+
+; CHECK: %[[#new_end]] = OpLabel
+; CHECK: %[[#route:]] = OpPhi %[[#int_ty]] %[[#int_0]] %[[#while_cond]] %[[#int_1]] %[[#tail]]
+; CHECK: OpSwitch %[[#route]] %[[#while_end]] 0 %[[#while_end_loopexit]]
+}
+
+declare token @llvm.experimental.convergence.entry() #2
+declare token @llvm.experimental.convergence.loop() #2
+declare i32 @__hlsl_wave_get_lane_index() #3
+
+attributes #0 = { convergent noinline norecurse nounwind optnone "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { convergent norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #2 = { convergent nocallback nofree nosync nounwind willreturn memory(none) }
+attributes #3 = { convergent }
+
+!llvm.module.flags = !{!0, !1}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 4, !"dx.disable_optimizations", i32 1}
+
diff --git a/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-multiple-break.ll b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-multiple-break.ll
new file mode 100644
index 0000000000000..32a97553df05e
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-multiple-break.ll
@@ -0,0 +1,103 @@
+; RUN: llc -mtriple=spirv-unknown-vulkan-compute -O0 %s -o - | FileCheck %s --match-full-lines
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-G1"
+target triple = "spirv-unknown-vulkan-compute"
+
+define internal spir_func void @main() #0 {
+
+; CHECK: OpDecorate %[[#builtin:]] BuiltIn SubgroupLocalInvocationId
+; CHECK-DAG: %[[#int_ty:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#pint_ty:]] = OpTypePointer Function %[[#int_ty]]
+; CHECK-DAG: %[[#bool_ty:]] = OpTypeBool
+; CHECK-DAG: %[[#int_0:]] = OpConstant %[[#int_ty]] 0
+; CHECK-DAG: %[[#int_1:]] = OpConstant %[[#int_ty]] 1
+; CHECK-DAG: %[[#int_2:]] = OpConstant %[[#int_ty]] 2
+; CHECK-DAG: %[[#int_10:]] = OpConstant %[[#int_ty]] 10
+
+; CHECK: %[[#entry:]] = OpLabel
+; CHECK: %[[#idx:]] = OpVariable %[[#pint_ty]] Function
+; CHECK: OpStore %[[#idx]] %[[#int_0]] Aligned 4
+; CHECK: OpBranch %[[#while_cond:]]
+entry:
+ %0 = call token @llvm.experimental.convergence.entry()
+ %idx = alloca i32, align 4
+ store i32 0, ptr %idx, align 4
+ br label %while.cond
+
+; CHECK: %[[#while_cond]] = OpLabel
+; CHECK: %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#idx]] Aligned 4
+; CHECK: %[[#cmp:]] = OpINotEqual %[[#bool_ty]] %[[#tmp]] %[[#int_10]]
+; CHECK: OpBranchConditional %[[#cmp]] %[[#while_body:]] %[[#new_end:]]
+while.cond:
+ %1 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %0) ]
+ %2 = load i32, ptr %idx, align 4
+ %cmp = icmp ne i32 %2, 10
+ br i1 %cmp, label %while.body, label %while.end
+
+; CHECK: %[[#while_body]] = OpLabel
+; CHECK-NEXT: %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#builtin]] Aligned 1
+; CHECK-NEXT: OpStore %[[#idx]] %[[#tmp]] Aligned 4
+; CHECK-NEXT: %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#idx]] Aligned 4
+; CHECK-NEXT: %[[#cmp1:]] = OpIEqual %[[#bool_ty]] %[[#tmp]] %[[#int_0]]
+; CHECK: OpBranchConditional %[[#cmp1]] %[[#new_end]] %[[#if_end:]]
+while.body:
+ %3 = call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %1) ]
+ store i32 %3, ptr %idx, align 4
+ %4 = load i32, ptr %idx, align 4
+ %cmp1 = icmp eq i32 %4, 0
+ br i1 %cmp1, label %if.then, label %if.end
+
+; CHECK: %[[#if_then:]] = OpLabel
+; CHECK: OpBranch %[[#while_end:]]
+if.then:
+ br label %while.end
+
+; CHECK: %[[#if_end]] = OpLabel
+; CHECK-NEXT: %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#builtin]] Aligned 1
+; CHECK-NEXT: OpStore %[[#idx]] %[[#tmp]] Aligned 4
+; CHECK-NEXT: %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#idx]] Aligned 4
+; CHECK-NEXT: %[[#cmp2:]] = OpIEqual %[[#bool_ty]] %[[#tmp]] %[[#int_0]]
+; CHECK: OpBranchConditional %[[#cmp2]] %[[#new_end]] %[[#if_end2:]]
+if.end:
+ %5 = call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %1) ]
+ store i32 %5, ptr %idx, align 4
+ %6 = load i32, ptr %idx, align 4
+ %cmp2 = icmp eq i32 %6, 0
+ br i1 %cmp2, label %if.then2, label %if.end2
+
+; CHECK: %[[#if_then2:]] = OpLabel
+; CHECK: OpBranch %[[#while_end:]]
+if.then2:
+ br label %while.end
+
+; CHECK: %[[#if_end2]] = OpLabel
+; CHECK: OpBranch %[[#while_cond:]]
+if.end2:
+ br label %while.cond
+
+; CHECK: %[[#while_end_loopexit:]] = OpLabel
+; CHECK: OpBranch %[[#while_end]]
+
+; CHECK: %[[#while_end]] = OpLabel
+; CHECK: OpReturn
+while.end:
+ ret void
+
+; CHECK: %[[#new_end]] = OpLabel
+; CHECK: %[[#route:]] = OpPhi %[[#int_ty]] %[[#int_2]] %[[#while_cond]] %[[#int_0]] %[[#while_body]] %[[#int_1]] %[[#if_end]]
+; CHECK: OpSwitch %[[#route]] %[[#while_end_loopexit]] 1 %[[#if_then2]] 0 %[[#if_then]]
+}
+
+declare token @llvm.experimental.convergence.entry() #2
+declare token @llvm.experimental.convergence.loop() #2
+declare i32 @__hlsl_wave_get_lane_index() #3
+
+attributes #0 = { convergent noinline norecurse nounwind optnone "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { convergent norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #2 = { convergent nocallback nofree nosync nounwind willreturn memory(none) }
+attributes #3 = { convergent }
+
+!llvm.module.flags = !{!0, !1}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 4, !"dx.disable_optimizations", i32 1}
diff --git a/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-simple-white-identity.ll b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-simple-white-identity.ll
new file mode 100644
index 0000000000000..a8bf4fb0db989
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-simple-white-identity.ll
@@ -0,0 +1,49 @@
+; RUN: llc -mtriple=spirv-unknown-vulkan-compute -O0 %s -o - | FileCheck %s
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-G1"
+target triple = "spirv-unknown-vulkan-compute"
+
+define internal spir_func void @main() #0 {
+
+; CHECK: %[[#entry:]] = OpLabel
+; CHECK: OpBranch %[[#while_cond:]]
+entry:
+ %0 = call token @llvm.experimental.convergence.entry()
+ %idx = alloca i32, align 4
+ store i32 -1, ptr %idx, align 4
+ br label %while.cond
+
+; CHECK: %[[#while_cond]] = OpLabel
+; CHECK: OpBranchConditional %[[#cond:]] %[[#while_body:]] %[[#while_end:]]
+while.cond:
+ %1 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %0) ]
+ %2 = load i32, ptr %idx, align 4
+ %cmp = icmp ne i32 %2, 0
+ br i1 %cmp, label %while.body, label %while.end
+
+; CHECK: %[[#while_body]] = OpLabel
+; CHECK: OpBranch %[[#while_cond]]
+while.body:
+ %3 = call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %1) ]
+ store i32 %3, ptr %idx, align 4
+ br label %while.cond
+
+ ; CHECK: %[[#while_end]] = OpLabel
+; CHECK-NEXT: OpReturn
+while.end:
+ ret void
+}
+
+declare token @llvm.experimental.convergence.entry() #2
+declare token @llvm.experimental.convergence.loop() #2
+declare i32 @__hlsl_wave_get_lane_index() #3
+
+attributes #0 = { convergent noinline norecurse nounwind optnone "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { convergent norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #2 = { convergent nocallback nofree nosync nounwind willreturn memory(none) }
+attributes #3 = { convergent }
+
+!llvm.module.flags = !{!0, !1}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 4, !"dx.disable_optimizations", i32 1}