[llvm] [SPIR-V] Fix flakiness during switch generation. (PR #95001)

Nathan Gauër via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 10 09:30:57 PDT 2024


https://github.com/Keenuts updated https://github.com/llvm/llvm-project/pull/95001

>From 442fd1cc9596e808a47c5d5cb82aa92f1f1ccb41 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Mon, 10 Jun 2024 14:00:49 +0200
Subject: [PATCH] [SPIR-V] Fix flakiness during switch generation.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The case lists of the switches generated by this pass were not
deterministic: their order depended on memory allocation patterns.
This is because the CaseList order relied on the iteration order of
an std::unordered_map keyed by BasicBlock pointers.
Building the switch cases from the sorted exit target list instead
should solve the problem.

Signed-off-by: Nathan Gauër <brioche at google.com>
---
 .../SPIRV/SPIRVMergeRegionExitTargets.cpp     | 25 +++++++++++--------
 .../SPIRV/structurizer/merge-exit-break.ll    |  2 +-
 .../merge-exit-convergence-in-break.ll        |  2 +-
 .../structurizer/merge-exit-multiple-break.ll |  2 +-
 4 files changed, 17 insertions(+), 14 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp b/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp
index 2744c25d1bc75..52354281cdd7e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp
@@ -17,6 +17,8 @@
 #include "SPIRVSubtarget.h"
 #include "SPIRVTargetMachine.h"
 #include "SPIRVUtils.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/CodeGen/IntrinsicLowering.h"
 #include "llvm/IR/CFG.h"
@@ -71,7 +73,7 @@ class SPIRVMergeRegionExitTargets : public FunctionPass {
   /// terminator will take.
   llvm::Value *createExitVariable(
       BasicBlock *BB,
-      const std::unordered_map<BasicBlock *, ConstantInt *> &TargetToValue) {
+      const DenseMap<BasicBlock *, ConstantInt *> &TargetToValue) {
     auto *T = BB->getTerminator();
     if (isa<ReturnInst>(T))
       return nullptr;
@@ -103,7 +105,7 @@ class SPIRVMergeRegionExitTargets : public FunctionPass {
 
   /// Replaces |BB|'s branch targets present in |ToReplace| with |NewTarget|.
   void replaceBranchTargets(BasicBlock *BB,
-                            const std::unordered_set<BasicBlock *> ToReplace,
+                            const SmallPtrSet<BasicBlock *, 4> &ToReplace,
                             BasicBlock *NewTarget) {
     auto *T = BB->getTerminator();
     if (isa<ReturnInst>(T))
@@ -133,7 +135,7 @@ class SPIRVMergeRegionExitTargets : public FunctionPass {
   bool runOnConvergenceRegionNoRecurse(LoopInfo &LI,
                                        const SPIRV::ConvergenceRegion *CR) {
     // Gather all the exit targets for this region.
-    std::unordered_set<BasicBlock *> ExitTargets;
+    SmallPtrSet<BasicBlock *, 4> ExitTargets;
     for (BasicBlock *Exit : CR->Exits) {
       for (BasicBlock *Target : gatherSuccessors(Exit)) {
         if (CR->Blocks.count(Target) == 0)
@@ -164,9 +166,10 @@ class SPIRVMergeRegionExitTargets : public FunctionPass {
 
     // Creating one constant per distinct exit target. This will be route to the
     // correct target.
-    std::unordered_map<BasicBlock *, ConstantInt *> TargetToValue;
+    DenseMap<BasicBlock *, ConstantInt *> TargetToValue;
     for (BasicBlock *Target : SortedExitTargets)
-      TargetToValue.emplace(Target, Builder.getInt32(TargetToValue.size()));
+      TargetToValue.insert(
+          std::make_pair(Target, Builder.getInt32(TargetToValue.size())));
 
     // Creating one variable per exit node, set to the constant matching the
     // targeted external block.
@@ -184,12 +187,12 @@ class SPIRVMergeRegionExitTargets : public FunctionPass {
     }
 
     // Creating the switch to jump to the correct exit target.
-    std::vector<std::pair<BasicBlock *, ConstantInt *>> CasesList(
-        TargetToValue.begin(), TargetToValue.end());
-    llvm::SwitchInst *Sw =
-        Builder.CreateSwitch(node, CasesList[0].first, CasesList.size() - 1);
-    for (size_t i = 1; i < CasesList.size(); i++)
-      Sw->addCase(CasesList[i].second, CasesList[i].first);
+    llvm::SwitchInst *Sw = Builder.CreateSwitch(node, SortedExitTargets[0],
+                                                SortedExitTargets.size() - 1);
+    for (size_t i = 1; i < SortedExitTargets.size(); i++) {
+      BasicBlock *BB = SortedExitTargets[i];
+      Sw->addCase(TargetToValue[BB], BB);
+    }
 
     // Fix exit branches to redirect to the new exit.
     for (auto Exit : CR->Exits)
diff --git a/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-break.ll b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-break.ll
index b3fcdc978625f..e7b1b441405f6 100644
--- a/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-break.ll
+++ b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-break.ll
@@ -66,7 +66,7 @@ while.end:
 
 ; CHECK:   %[[#new_end]] = OpLabel
 ; CHECK:    %[[#route:]] = OpPhi %[[#int_ty]] %[[#int_1]] %[[#while_cond]] %[[#int_0]] %[[#while_body]]
-; CHECK:                   OpSwitch %[[#route]] %[[#while_end_loopexit]] 0 %[[#if_then]]
+; CHECK:                   OpSwitch %[[#route]] %[[#if_then]] 1 %[[#while_end_loopexit]]
 }
 
 declare token @llvm.experimental.convergence.entry() #2
diff --git a/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-convergence-in-break.ll b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-convergence-in-break.ll
index a67c58fdd5749..593e3631c02b9 100644
--- a/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-convergence-in-break.ll
+++ b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-convergence-in-break.ll
@@ -75,7 +75,7 @@ while.end:
 
 ; CHECK:   %[[#new_end]] = OpLabel
 ; CHECK:    %[[#route:]] = OpPhi %[[#int_ty]] %[[#int_0]] %[[#while_cond]] %[[#int_1]] %[[#tail]]
-; CHECK:                   OpSwitch %[[#route]] %[[#while_end]] 0 %[[#while_end_loopexit]]
+; CHECK:                   OpSwitch %[[#route]] %[[#while_end_loopexit]] 1 %[[#while_end]]
 }
 
 declare token @llvm.experimental.convergence.entry() #2
diff --git a/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-multiple-break.ll b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-multiple-break.ll
index 32a97553df05e..9806dd7955468 100644
--- a/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-multiple-break.ll
+++ b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-multiple-break.ll
@@ -85,7 +85,7 @@ while.end:
 
 ; CHECK:   %[[#new_end]] = OpLabel
 ; CHECK:    %[[#route:]] = OpPhi %[[#int_ty]] %[[#int_2]] %[[#while_cond]] %[[#int_0]] %[[#while_body]] %[[#int_1]] %[[#if_end]]
-; CHECK:                   OpSwitch %[[#route]] %[[#while_end_loopexit]] 1 %[[#if_then2]] 0 %[[#if_then]]
+; CHECK:                   OpSwitch %[[#route]] %[[#if_then]] 1 %[[#if_then2]] 2 %[[#while_end_loopexit]]
 }
 
 declare token @llvm.experimental.convergence.entry() #2



More information about the llvm-commits mailing list