[llvm] [SPIR-V] Fix flakiness during switch generation. (PR #95001)
Nathan Gauër via llvm-commits
llvm-commits at lists.llvm.org
Mon Jun 10 09:30:57 PDT 2024
https://github.com/Keenuts updated https://github.com/llvm/llvm-project/pull/95001
>From 442fd1cc9596e808a47c5d5cb82aa92f1f1ccb41 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Mon, 10 Jun 2024 14:00:49 +0200
Subject: [PATCH] [SPIR-V] Fix flakiness during switch generation.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
The case lists of the switches generated by this pass were not
deterministic: their order depended on memory allocation patterns.
This is because the CasesList order came from iterating an
std::unordered_map keyed by BasicBlock pointers (TargetToValue).
Building the switch from the sorted exit-target list instead makes the case order deterministic.
Signed-off-by: Nathan Gauër <brioche at google.com>
---
.../SPIRV/SPIRVMergeRegionExitTargets.cpp | 25 +++++++++++--------
.../SPIRV/structurizer/merge-exit-break.ll | 2 +-
.../merge-exit-convergence-in-break.ll | 2 +-
.../structurizer/merge-exit-multiple-break.ll | 2 +-
4 files changed, 17 insertions(+), 14 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp b/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp
index 2744c25d1bc75..52354281cdd7e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp
@@ -17,6 +17,8 @@
#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/CodeGen/IntrinsicLowering.h"
#include "llvm/IR/CFG.h"
@@ -71,7 +73,7 @@ class SPIRVMergeRegionExitTargets : public FunctionPass {
/// terminator will take.
llvm::Value *createExitVariable(
BasicBlock *BB,
- const std::unordered_map<BasicBlock *, ConstantInt *> &TargetToValue) {
+ const DenseMap<BasicBlock *, ConstantInt *> &TargetToValue) {
auto *T = BB->getTerminator();
if (isa<ReturnInst>(T))
return nullptr;
@@ -103,7 +105,7 @@ class SPIRVMergeRegionExitTargets : public FunctionPass {
/// Replaces |BB|'s branch targets present in |ToReplace| with |NewTarget|.
void replaceBranchTargets(BasicBlock *BB,
- const std::unordered_set<BasicBlock *> ToReplace,
+ const SmallPtrSet<BasicBlock *, 4> &ToReplace,
BasicBlock *NewTarget) {
auto *T = BB->getTerminator();
if (isa<ReturnInst>(T))
@@ -133,7 +135,7 @@ class SPIRVMergeRegionExitTargets : public FunctionPass {
bool runOnConvergenceRegionNoRecurse(LoopInfo &LI,
const SPIRV::ConvergenceRegion *CR) {
// Gather all the exit targets for this region.
- std::unordered_set<BasicBlock *> ExitTargets;
+ SmallPtrSet<BasicBlock *, 4> ExitTargets;
for (BasicBlock *Exit : CR->Exits) {
for (BasicBlock *Target : gatherSuccessors(Exit)) {
if (CR->Blocks.count(Target) == 0)
@@ -164,9 +166,10 @@ class SPIRVMergeRegionExitTargets : public FunctionPass {
// Creating one constant per distinct exit target. This will be route to the
// correct target.
- std::unordered_map<BasicBlock *, ConstantInt *> TargetToValue;
+ DenseMap<BasicBlock *, ConstantInt *> TargetToValue;
for (BasicBlock *Target : SortedExitTargets)
- TargetToValue.emplace(Target, Builder.getInt32(TargetToValue.size()));
+ TargetToValue.insert(
+ std::make_pair(Target, Builder.getInt32(TargetToValue.size())));
// Creating one variable per exit node, set to the constant matching the
// targeted external block.
@@ -184,12 +187,12 @@ class SPIRVMergeRegionExitTargets : public FunctionPass {
}
// Creating the switch to jump to the correct exit target.
- std::vector<std::pair<BasicBlock *, ConstantInt *>> CasesList(
- TargetToValue.begin(), TargetToValue.end());
- llvm::SwitchInst *Sw =
- Builder.CreateSwitch(node, CasesList[0].first, CasesList.size() - 1);
- for (size_t i = 1; i < CasesList.size(); i++)
- Sw->addCase(CasesList[i].second, CasesList[i].first);
+ llvm::SwitchInst *Sw = Builder.CreateSwitch(node, SortedExitTargets[0],
+ SortedExitTargets.size() - 1);
+ for (size_t i = 1; i < SortedExitTargets.size(); i++) {
+ BasicBlock *BB = SortedExitTargets[i];
+ Sw->addCase(TargetToValue[BB], BB);
+ }
// Fix exit branches to redirect to the new exit.
for (auto Exit : CR->Exits)
diff --git a/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-break.ll b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-break.ll
index b3fcdc978625f..e7b1b441405f6 100644
--- a/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-break.ll
+++ b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-break.ll
@@ -66,7 +66,7 @@ while.end:
; CHECK: %[[#new_end]] = OpLabel
; CHECK: %[[#route:]] = OpPhi %[[#int_ty]] %[[#int_1]] %[[#while_cond]] %[[#int_0]] %[[#while_body]]
-; CHECK: OpSwitch %[[#route]] %[[#while_end_loopexit]] 0 %[[#if_then]]
+; CHECK: OpSwitch %[[#route]] %[[#if_then]] 1 %[[#while_end_loopexit]]
}
declare token @llvm.experimental.convergence.entry() #2
diff --git a/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-convergence-in-break.ll b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-convergence-in-break.ll
index a67c58fdd5749..593e3631c02b9 100644
--- a/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-convergence-in-break.ll
+++ b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-convergence-in-break.ll
@@ -75,7 +75,7 @@ while.end:
; CHECK: %[[#new_end]] = OpLabel
; CHECK: %[[#route:]] = OpPhi %[[#int_ty]] %[[#int_0]] %[[#while_cond]] %[[#int_1]] %[[#tail]]
-; CHECK: OpSwitch %[[#route]] %[[#while_end]] 0 %[[#while_end_loopexit]]
+; CHECK: OpSwitch %[[#route]] %[[#while_end_loopexit]] 1 %[[#while_end]]
}
declare token @llvm.experimental.convergence.entry() #2
diff --git a/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-multiple-break.ll b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-multiple-break.ll
index 32a97553df05e..9806dd7955468 100644
--- a/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-multiple-break.ll
+++ b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-multiple-break.ll
@@ -85,7 +85,7 @@ while.end:
; CHECK: %[[#new_end]] = OpLabel
; CHECK: %[[#route:]] = OpPhi %[[#int_ty]] %[[#int_2]] %[[#while_cond]] %[[#int_0]] %[[#while_body]] %[[#int_1]] %[[#if_end]]
-; CHECK: OpSwitch %[[#route]] %[[#while_end_loopexit]] 1 %[[#if_then2]] 0 %[[#if_then]]
+; CHECK: OpSwitch %[[#route]] %[[#if_then]] 1 %[[#if_then2]] 2 %[[#while_end_loopexit]]
}
declare token @llvm.experimental.convergence.entry() #2
More information about the llvm-commits
mailing list