[llvm] [Utils][SPIR-V] Adding spirv-sim to LLVM (PR #104020)

Nathan Gauër via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 22 08:32:34 PDT 2024


https://github.com/Keenuts updated https://github.com/llvm/llvm-project/pull/104020

>From e8358c77ada99c1e3baa300ee28bff34edca36a5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Wed, 14 Aug 2024 13:17:00 +0200
Subject: [PATCH 1/5] [Utils][SPIR-V] Adding spirv-sim to LLVM

Currently, the testing infrastructure for SPIR-V is based on FileCheck.
Those tests are great for checking some level of codegen, but when a test
needs to check both the CFG layout and the content of each basic block,
things become messy.

- Because the state carried by CHECK/CHECK-DAG/CHECK-NEXT is limited, it is
  sometimes hard to match the intended block: if 2 basic blocks have similar
  instructions, FileCheck can match the wrong one.

- Cross-lane interactions can be difficult to reason about, and writing
  a FileCheck test that is strong enough to catch bad CFG transforms while
  not breaking every time some unrelated part of codegen changes is hard.

And lastly, the spirv-val tooling we have checks that the generated
SPIR-V respects the spec, not that it is correct with respect to the
source IR.

For those reasons, I believe the best way to test the structurizer is
to:
 - run spirv-val to make sure the CFG respects the spec.
 - simulate the function to validate the result of each lane, making sure
   the generated code is correct.
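As a rough illustration of the per-lane check in the second step, here is a minimal Python sketch. The regex mirrors RE_EXPECTS from spirv-sim.py and the messages mirror the CHECK lines in the tests, but parse_expects and check_lanes are hypothetical helpers written for this example, not the simulator's actual code.

```python
import re

# Same pattern as RE_EXPECTS in spirv-sim.py: a comma-separated list of
# non-negative integers, one per lane, e.g. "5,6,6".
RE_EXPECTS = re.compile(r"^([0-9]+,)*[0-9]+$")


def parse_expects(text: str) -> list[int]:
    """Hypothetical helper: turn an --expects value into per-lane ints."""
    if RE_EXPECTS.match(text) is None:
        raise ValueError("Invalid format for --expects/-e flag.")
    return [int(x) for x in text.split(",")]


def check_lanes(expected: list[int], observed: list[int]) -> bool:
    """Hypothetical helper: compare the value returned by each lane."""
    if len(expected) != len(observed):
        raise ValueError("Wave size != expected result array size")
    if expected != observed:
        # Report the whole wave so a single diverging lane is easy to spot.
        print("Expected != Observed")
        print(f"{expected} != {observed}")
        return False
    return True


# branch.spv: lane 0 takes the true branch (5), the other lanes return 6.
print(check_lanes(parse_expects("5,6,6"), [5, 6, 6]))  # prints True
```

Comparing whole waves rather than single values is what lets a test catch a CFG transform that is only wrong for some lanes.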

This simulator has no dependencies other than core Python. It also only
supports a very limited set of instructions, since most features can be
tested through control flow and some basic cross-lane interactions.

As-is, the added tests are just a harness for the simulator itself.
If this gets merged, the structurizer PR will benefit from it, as I'll
be able to add extensive tests using the simulator.
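To make the cross-lane expectations concrete, here is a hand-written Python reference model for the HLSL function commented in wave-read-lane-first.spv. It is not part of the patch. It assumes that the lanes entering the `if` in the same loop iteration execute OpGroupNonUniformBroadcastFirst together, so every one of them receives the idx of the lowest active lane. It also assumes that the wave size equals len(m), as in the test (--wave=4). The function name is hypothetical.

```python
def wave_read_lane_first_reference(m: list[int], loop_count: int = 4) -> list[int]:
    """Per-lane result of simple(): lane idx returns WaveReadLaneFirst(idx)
    at the first iteration i where i == m[idx], or 0 if it never matches."""
    wave_size = len(m)
    results = [0] * wave_size
    running = set(range(wave_size))  # lanes still inside the loop
    for i in range(loop_count):
        # Lanes that take the `if` (and return) during this iteration.
        active = sorted(lane for lane in running if m[lane] == i)
        if active:
            first = active[0]  # idx of the first active lane is broadcast
            for lane in active:
                results[lane] = first
            running.difference_update(active)
    return results


print(wave_read_lane_first_reference([0, 1, 2, 0]))  # prints [0, 1, 2, 0]
```

Lanes 0 and 3 return in the same iteration, so lane 3 observes lane 0's value. This is the kind of behavior that a CFG transform merging or splitting those exits would break.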
---
 llvm/test/Other/spirv-sim/branch.spv          |  42 ++
 llvm/test/Other/spirv-sim/call.spv            |  36 +
 llvm/test/Other/spirv-sim/lit.local.cfg       |   2 +
 llvm/test/Other/spirv-sim/loop.spv            |  58 ++
 .../Other/spirv-sim/simple-bad-result.spv     |  26 +
 llvm/test/Other/spirv-sim/simple.spv          |  22 +
 llvm/test/Other/spirv-sim/simulator-args.spv  |  36 +
 llvm/test/Other/spirv-sim/switch.spv          |  42 ++
 .../Other/spirv-sim/wave-get-lane-index.spv   |  30 +
 .../Other/spirv-sim/wave-read-lane-first.spv  |  83 +++
 llvm/test/lit.cfg.py                          |   2 +-
 llvm/utils/spirv-sim/instructions.py          | 387 +++++++++++
 llvm/utils/spirv-sim/spirv-sim.py             | 627 ++++++++++++++++++
 13 files changed, 1392 insertions(+), 1 deletion(-)
 create mode 100644 llvm/test/Other/spirv-sim/branch.spv
 create mode 100644 llvm/test/Other/spirv-sim/call.spv
 create mode 100644 llvm/test/Other/spirv-sim/lit.local.cfg
 create mode 100644 llvm/test/Other/spirv-sim/loop.spv
 create mode 100644 llvm/test/Other/spirv-sim/simple-bad-result.spv
 create mode 100644 llvm/test/Other/spirv-sim/simple.spv
 create mode 100644 llvm/test/Other/spirv-sim/simulator-args.spv
 create mode 100644 llvm/test/Other/spirv-sim/switch.spv
 create mode 100644 llvm/test/Other/spirv-sim/wave-get-lane-index.spv
 create mode 100644 llvm/test/Other/spirv-sim/wave-read-lane-first.spv
 create mode 100644 llvm/utils/spirv-sim/instructions.py
 create mode 100755 llvm/utils/spirv-sim/spirv-sim.py

diff --git a/llvm/test/Other/spirv-sim/branch.spv b/llvm/test/Other/spirv-sim/branch.spv
new file mode 100644
index 00000000000000..7ce0e7da3f058b
--- /dev/null
+++ b/llvm/test/Other/spirv-sim/branch.spv
@@ -0,0 +1,42 @@
+; RUN: spirv-sim --function=simple --wave=3 --expects=5,6,6 -i %s
+; RUN: %if spirv-tools %{ spirv-as %s -o - | spirv-val - %}
+               OpCapability Shader
+               OpCapability GroupNonUniform
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %main "main" %WaveIndex
+               OpExecutionMode %main LocalSize 1 1 1
+               OpSource HLSL 670
+               OpName %simple "simple"
+               OpName %main "main"
+                OpDecorate %WaveIndex BuiltIn SubgroupLocalInvocationId
+        %int = OpTypeInt 32 1
+       %uint = OpTypeInt 32 0
+       %bool = OpTypeBool
+      %int_2 = OpConstant %int 2
+      %int_5 = OpConstant %int 5
+      %int_6 = OpConstant %int 6
+     %uint_0 = OpConstant %uint 0
+       %void = OpTypeVoid
+  %main_type = OpTypeFunction %void
+%simple_type = OpTypeFunction %int
+  %uint_iptr = OpTypePointer Input %uint
+  %WaveIndex = OpVariable %uint_iptr Input
+       %main = OpFunction %void None %main_type
+      %entry = OpLabel
+               OpReturn
+               OpFunctionEnd
+     %simple = OpFunction %int None %simple_type
+          %1 = OpLabel
+          %2 = OpLoad %uint %WaveIndex
+          %3 = OpIEqual %bool %uint_0 %2
+               OpSelectionMerge %merge None
+               OpBranchConditional %3 %true %false
+       %true = OpLabel
+               OpBranch %merge
+      %false = OpLabel
+               OpBranch %merge
+      %merge = OpLabel
+          %4 = OpPhi %int %int_5 %true %int_6 %false
+               OpReturnValue %4
+               OpFunctionEnd
+
diff --git a/llvm/test/Other/spirv-sim/call.spv b/llvm/test/Other/spirv-sim/call.spv
new file mode 100644
index 00000000000000..320b048f95296c
--- /dev/null
+++ b/llvm/test/Other/spirv-sim/call.spv
@@ -0,0 +1,36 @@
+; RUN: %if spirv-tools %{ spirv-as %s -o - | spirv-val - %}
+; RUN: spirv-sim --function=simple --wave=1 --expects=2 -i %s
+               OpCapability Shader
+               OpCapability GroupNonUniform
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %main "main" %WaveIndex
+               OpExecutionMode %main LocalSize 1 1 1
+               OpSource HLSL 670
+               OpName %simple "simple"
+               OpName %main "main"
+                OpDecorate %WaveIndex BuiltIn SubgroupLocalInvocationId
+        %int = OpTypeInt 32 1
+       %uint = OpTypeInt 32 0
+     %uint_2 = OpConstant %uint 2
+       %void = OpTypeVoid
+  %main_type = OpTypeFunction %void
+%simple_type = OpTypeFunction %int
+   %sub_type = OpTypeFunction %uint
+  %uint_iptr = OpTypePointer Input %uint
+  %WaveIndex = OpVariable %uint_iptr Input
+       %main = OpFunction %void None %main_type
+      %entry = OpLabel
+               OpReturn
+               OpFunctionEnd
+        %sub = OpFunction %uint None %sub_type
+          %a = OpLabel
+               OpReturnValue %uint_2
+               OpFunctionEnd
+     %simple = OpFunction %int None %simple_type
+          %1 = OpLabel
+          %2 = OpFunctionCall %uint %sub
+          %3 = OpBitcast %int %2
+               OpReturnValue %3
+               OpFunctionEnd
+
+
diff --git a/llvm/test/Other/spirv-sim/lit.local.cfg b/llvm/test/Other/spirv-sim/lit.local.cfg
new file mode 100644
index 00000000000000..d343a8f2bd9b09
--- /dev/null
+++ b/llvm/test/Other/spirv-sim/lit.local.cfg
@@ -0,0 +1,2 @@
+spirv_sim_root = os.path.join(config.llvm_src_root, "utils", "spirv-sim")
+config.substitutions.append(("spirv-sim", os.path.join(spirv_sim_root, "spirv-sim.py")))
diff --git a/llvm/test/Other/spirv-sim/loop.spv b/llvm/test/Other/spirv-sim/loop.spv
new file mode 100644
index 00000000000000..c753ea2c149410
--- /dev/null
+++ b/llvm/test/Other/spirv-sim/loop.spv
@@ -0,0 +1,58 @@
+; RUN: spirv-sim --function=simple --wave=4 --expects=0,2,2,4 -i %s
+; RUN: %if spirv-tools %{ spirv-as %s -o - | spirv-val - %}
+               OpCapability Shader
+               OpCapability GroupNonUniform
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %main "main" %WaveIndex
+               OpExecutionMode %main LocalSize 1 1 1
+               OpSource HLSL 670
+               OpName %simple "simple"
+               OpName %main "main"
+                OpDecorate %WaveIndex BuiltIn SubgroupLocalInvocationId
+        %int = OpTypeInt 32 1
+       %uint = OpTypeInt 32 0
+       %bool = OpTypeBool
+      %int_2 = OpConstant %int 2
+      %int_5 = OpConstant %int 5
+      %int_6 = OpConstant %int 6
+     %uint_0 = OpConstant %uint 0
+     %uint_2 = OpConstant %uint 2
+       %void = OpTypeVoid
+  %main_type = OpTypeFunction %void
+%simple_type = OpTypeFunction %int
+  %uint_iptr = OpTypePointer Input %uint
+  %uint_fptr = OpTypePointer Function %uint
+  %WaveIndex = OpVariable %uint_iptr Input
+       %main = OpFunction %void None %main_type
+      %unused = OpLabel
+               OpReturn
+               OpFunctionEnd
+     %simple = OpFunction %int None %simple_type
+      %entry = OpLabel
+; uint i = 0;
+          %i = OpVariable %uint_fptr Function
+          %1 = OpLoad %uint %WaveIndex
+               OpStore %i %uint_0
+               OpBranch %header
+     %header = OpLabel
+          %2 = OpLoad %uint %i
+          %3 = OpULessThan %bool %2 %1
+               OpLoopMerge %merge %continue None
+               OpBranchConditional %3 %body %merge
+; while (i < WaveGetLaneIndex()) {
+;     i += 2;
+; }
+       %body = OpLabel
+               OpBranch %continue
+   %continue = OpLabel
+          %4 = OpIAdd %uint %2 %uint_2
+               OpStore %i %4
+               OpBranch %header
+      %merge = OpLabel
+; return (int) i;
+          %5 = OpLoad %uint %i
+          %6 = OpBitcast %int %5
+               OpReturnValue %6
+               OpFunctionEnd
+
+
diff --git a/llvm/test/Other/spirv-sim/simple-bad-result.spv b/llvm/test/Other/spirv-sim/simple-bad-result.spv
new file mode 100644
index 00000000000000..f4dd046cc078bc
--- /dev/null
+++ b/llvm/test/Other/spirv-sim/simple-bad-result.spv
@@ -0,0 +1,26 @@
+; RUN: %if spirv-tools %{ spirv-as %s -o - | spirv-val - %}
+; RUN: not spirv-sim --function=simple --wave=1 --expects=1 -i %s 2>&1 | FileCheck %s
+
+; CHECK: Expected != Observed
+; CHECK: [1] != [2]
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %main "main"
+               OpExecutionMode %main LocalSize 1 1 1
+               OpSource HLSL 670
+               OpName %simple "simple"
+               OpName %main "main"
+        %int = OpTypeInt 32 1
+      %int_2 = OpConstant %int 2
+       %void = OpTypeVoid
+  %main_type = OpTypeFunction %void
+   %simple_type = OpTypeFunction %int
+       %main = OpFunction %void None %main_type
+      %entry = OpLabel
+               OpReturn
+               OpFunctionEnd
+     %simple = OpFunction %int None %simple_type
+          %1 = OpLabel
+               OpReturnValue %int_2
+               OpFunctionEnd
+
diff --git a/llvm/test/Other/spirv-sim/simple.spv b/llvm/test/Other/spirv-sim/simple.spv
new file mode 100644
index 00000000000000..8c06192ea6e3d4
--- /dev/null
+++ b/llvm/test/Other/spirv-sim/simple.spv
@@ -0,0 +1,22 @@
+; RUN: %if spirv-tools %{ spirv-as %s -o - | spirv-val - %}
+; RUN: spirv-sim --function=simple --wave=1 --expects=2 -i %s
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %main "main"
+               OpExecutionMode %main LocalSize 1 1 1
+               OpSource HLSL 670
+               OpName %simple "simple"
+               OpName %main "main"
+        %int = OpTypeInt 32 1
+      %int_2 = OpConstant %int 2
+       %void = OpTypeVoid
+  %main_type = OpTypeFunction %void
+   %simple_type = OpTypeFunction %int
+       %main = OpFunction %void None %main_type
+      %entry = OpLabel
+               OpReturn
+               OpFunctionEnd
+     %simple = OpFunction %int None %simple_type
+          %1 = OpLabel
+               OpReturnValue %int_2
+               OpFunctionEnd
diff --git a/llvm/test/Other/spirv-sim/simulator-args.spv b/llvm/test/Other/spirv-sim/simulator-args.spv
new file mode 100644
index 00000000000000..d8b10180641584
--- /dev/null
+++ b/llvm/test/Other/spirv-sim/simulator-args.spv
@@ -0,0 +1,36 @@
+; RUN: not spirv-sim --function=simple --wave=a --expects=2 -i %s 2>&1 | FileCheck %s --check-prefixes=CHECK-WAVE
+; RUN: not spirv-sim --function=simple --wave=1 --expects=a -i %s 2>&1 | FileCheck %s --check-prefixes=CHECK-EXPECT
+; RUN: not spirv-sim --function=simple --wave=1 --expects=1, -i %s 2>&1 | FileCheck %s --check-prefixes=CHECK-EXPECT
+; RUN: not spirv-sim --function=simple --wave=2 --expects=1 -i %s 2>&1 | FileCheck %s --check-prefixes=CHECK-SIZE
+; RUN: not spirv-sim --function=foo --wave=1 --expects=1 -i %s 2>&1 | FileCheck %s --check-prefixes=CHECK-NAME
+
+; CHECK-WAVE: Invalid format for --wave/-w flag.
+
+; CHECK-EXPECT: Invalid format for --expects/-e flag.
+
+; CHECK-SIZE: Wave size != expected result array size
+
+; CHECK-NAME:          'foo' function not found. Known functions are:
+; CHECK-NAME-NEXT:     - main
+; CHECK-NAME-NEXT:     - simple
+; CHECK-NAME-NOT:      -
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %main "main"
+               OpExecutionMode %main LocalSize 1 1 1
+               OpSource HLSL 670
+               OpName %simple "simple"
+               OpName %main "main"
+        %int = OpTypeInt 32 1
+      %int_2 = OpConstant %int 2
+       %void = OpTypeVoid
+  %main_type = OpTypeFunction %void
+   %simple_type = OpTypeFunction %int
+       %main = OpFunction %void None %main_type
+      %entry = OpLabel
+               OpReturn
+               OpFunctionEnd
+     %simple = OpFunction %int None %simple_type
+          %1 = OpLabel
+               OpReturnValue %int_2
+               OpFunctionEnd
diff --git a/llvm/test/Other/spirv-sim/switch.spv b/llvm/test/Other/spirv-sim/switch.spv
new file mode 100644
index 00000000000000..83dc56cecef2aa
--- /dev/null
+++ b/llvm/test/Other/spirv-sim/switch.spv
@@ -0,0 +1,42 @@
+; RUN: %if spirv-tools %{ spirv-as %s -o - | spirv-val - %}
+; RUN: spirv-sim --function=simple --wave=4 --expects=0,1,2,0 -i %s
+               OpCapability Shader
+               OpCapability GroupNonUniform
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %main "main" %WaveIndex
+               OpExecutionMode %main LocalSize 1 1 1
+               OpSource HLSL 670
+               OpName %simple "simple"
+               OpName %main "main"
+                OpDecorate %WaveIndex BuiltIn SubgroupLocalInvocationId
+        %int = OpTypeInt 32 1
+       %uint = OpTypeInt 32 0
+       %bool = OpTypeBool
+      %int_0 = OpConstant %int 0
+      %int_1 = OpConstant %int 1
+      %int_2 = OpConstant %int 2
+     %uint_0 = OpConstant %uint 0
+       %void = OpTypeVoid
+  %main_type = OpTypeFunction %void
+%simple_type = OpTypeFunction %int
+  %uint_iptr = OpTypePointer Input %uint
+  %WaveIndex = OpVariable %uint_iptr Input
+       %main = OpFunction %void None %main_type
+      %entry = OpLabel
+               OpReturn
+               OpFunctionEnd
+     %simple = OpFunction %int None %simple_type
+          %1 = OpLabel
+          %2 = OpLoad %uint %WaveIndex
+               OpSelectionMerge %merge None
+               OpSwitch %2 %default 1 %case_1 2 %case_2
+    %default = OpLabel
+               OpBranch %merge
+     %case_1 = OpLabel
+               OpBranch %merge
+     %case_2 = OpLabel
+               OpBranch %merge
+      %merge = OpLabel
+          %4 = OpPhi %int %int_0 %default %int_1 %case_1 %int_2 %case_2
+               OpReturnValue %4
+               OpFunctionEnd
diff --git a/llvm/test/Other/spirv-sim/wave-get-lane-index.spv b/llvm/test/Other/spirv-sim/wave-get-lane-index.spv
new file mode 100644
index 00000000000000..1c1e5e8aefd4f9
--- /dev/null
+++ b/llvm/test/Other/spirv-sim/wave-get-lane-index.spv
@@ -0,0 +1,30 @@
+; RUN: %if spirv-tools %{ spirv-as %s -o - | spirv-val - %}
+; RUN: spirv-sim --function=simple --wave=4 --expects=0,1,2,3 -i %s
+               OpCapability Shader
+               OpCapability GroupNonUniform
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %main "main" %WaveIndex
+               OpExecutionMode %main LocalSize 1 1 1
+               OpSource HLSL 670
+               OpName %simple "simple"
+               OpName %main "main"
+                OpDecorate %WaveIndex BuiltIn SubgroupLocalInvocationId
+        %int = OpTypeInt 32 1
+       %uint = OpTypeInt 32 0
+      %int_2 = OpConstant %int 2
+       %void = OpTypeVoid
+  %main_type = OpTypeFunction %void
+%simple_type = OpTypeFunction %int
+  %uint_iptr = OpTypePointer Input %uint
+  %WaveIndex = OpVariable %uint_iptr Input
+       %main = OpFunction %void None %main_type
+      %entry = OpLabel
+               OpReturn
+               OpFunctionEnd
+     %simple = OpFunction %int None %simple_type
+          %1 = OpLabel
+          %2 = OpLoad %uint %WaveIndex
+          %3 = OpBitcast %int %2
+               OpReturnValue %3
+               OpFunctionEnd
+
diff --git a/llvm/test/Other/spirv-sim/wave-read-lane-first.spv b/llvm/test/Other/spirv-sim/wave-read-lane-first.spv
new file mode 100644
index 00000000000000..801fb55fbaa9f6
--- /dev/null
+++ b/llvm/test/Other/spirv-sim/wave-read-lane-first.spv
@@ -0,0 +1,83 @@
+; RUN: %if spirv-tools %{ spirv-as %s -o - | spirv-val - %}
+; RUN: spirv-sim --function=simple --wave=4 --expects=0,1,2,0 -i %s
+
+; int simple() {
+;   int m[4] = { 0, 1, 2, 0 };
+;   int idx = WaveGetLaneIndex();
+;   for (int i = 0; i < 4; i++) {
+;     if (i == m[idx]) {
+;       return WaveReadLaneFirst(idx);
+;     }
+;   }
+;   return 0;
+; }
+                       OpCapability Shader
+                       OpCapability GroupNonUniform
+                       OpCapability GroupNonUniformBallot
+                       OpMemoryModel Logical GLSL450
+                       OpEntryPoint GLCompute %main "main" %WaveIndex
+                       OpExecutionMode %main LocalSize 1 1 1
+                       OpSource HLSL 670
+                       OpName %simple "simple"
+                       OpName %main "main"
+                       OpDecorate %WaveIndex BuiltIn SubgroupLocalInvocationId
+                %int = OpTypeInt 32 1
+               %uint = OpTypeInt 32 0
+               %bool = OpTypeBool
+              %int_0 = OpConstant %int 0
+              %int_1 = OpConstant %int 1
+              %int_2 = OpConstant %int 2
+              %int_4 = OpConstant %int 4
+             %uint_3 = OpConstant %uint 3
+             %uint_4 = OpConstant %uint 4
+               %void = OpTypeVoid
+          %main_type = OpTypeFunction %void
+        %simple_type = OpTypeFunction %int
+          %uint_iptr = OpTypePointer Input %uint
+           %int_fptr = OpTypePointer Function %int
+     %arr_int_uint_4 = OpTypeArray %int %uint_4
+%arr_int_uint_4_fptr = OpTypePointer Function %arr_int_uint_4
+          %WaveIndex = OpVariable %uint_iptr Input
+               %main = OpFunction %void None %main_type
+              %entry = OpLabel
+                       OpReturn
+                       OpFunctionEnd
+             %simple = OpFunction %int None %simple_type
+         %bb_entry_0 = OpLabel
+                  %m = OpVariable %arr_int_uint_4_fptr Function
+                %idx = OpVariable %int_fptr Function
+                  %i = OpVariable %int_fptr Function
+                 %27 = OpCompositeConstruct %arr_int_uint_4 %int_0 %int_1 %int_2 %int_0
+                       OpStore %m %27
+                 %28 = OpLoad %uint %WaveIndex
+                 %29 = OpBitcast %int %28
+                       OpStore %idx %29
+                       OpStore %i %int_0
+                       OpBranch %for_check
+          %for_check = OpLabel
+                 %31 = OpLoad %int %i
+                 %33 = OpSLessThan %bool %31 %int_4
+                       OpLoopMerge %for_merge %for_continue None
+                       OpBranchConditional %33 %for_body %for_merge
+           %for_body = OpLabel
+                 %37 = OpLoad %int %i
+                 %38 = OpLoad %int %idx
+                 %39 = OpAccessChain %int_fptr %m %38
+                 %40 = OpLoad %int %39
+                 %41 = OpIEqual %bool %37 %40
+                       OpSelectionMerge %if_merge None
+                       OpBranchConditional %41 %if_true %if_merge
+            %if_true = OpLabel
+                 %44 = OpLoad %int %idx
+                 %45 = OpGroupNonUniformBroadcastFirst %int %uint_3 %44
+                       OpReturnValue %45
+           %if_merge = OpLabel
+                       OpBranch %for_continue
+       %for_continue = OpLabel
+                 %47 = OpLoad %int %i
+                 %48 = OpIAdd %int %47 %int_1
+                       OpStore %i %48
+                       OpBranch %for_check
+          %for_merge = OpLabel
+                       OpReturnValue %int_0
+                       OpFunctionEnd
diff --git a/llvm/test/lit.cfg.py b/llvm/test/lit.cfg.py
index e5e3dc7e1b4bd0..03448bd4174413 100644
--- a/llvm/test/lit.cfg.py
+++ b/llvm/test/lit.cfg.py
@@ -22,7 +22,7 @@
 
 # suffixes: A list of file extensions to treat as test files. This is overriden
 # by individual lit.local.cfg files in the test subdirectories.
-config.suffixes = [".ll", ".c", ".test", ".txt", ".s", ".mir", ".yaml"]
+config.suffixes = [".ll", ".c", ".test", ".txt", ".s", ".mir", ".yaml", ".spv"]
 
 # excludes: A list of directories to exclude from the testsuite. The 'Inputs'
 # subdirectories contain auxiliary inputs for various tests in their parent
diff --git a/llvm/utils/spirv-sim/instructions.py b/llvm/utils/spirv-sim/instructions.py
new file mode 100644
index 00000000000000..5f4ef831de0645
--- /dev/null
+++ b/llvm/utils/spirv-sim/instructions.py
@@ -0,0 +1,387 @@
+# Base class for an instruction. To implement a basic instruction that doesn't
+# impact the control-flow, create a new class inheriting from this.
+class Instruction:
+    _result: str
+
+    def __init__(self, line: str):
+        # Contains the name of the output register, if any.
+        self._result: str = None
+        # Contains the instruction opcode.
+        self._opcode: str = None
+        # Contains all the instruction operands, except result and opcode.
+        self._operands: list[str] = []
+
+        self.line = line
+        tokens = line.split()
+        if len(tokens) > 1 and tokens[1] == "=":
+            self._result = tokens[0]
+            self._opcode = tokens[2]
+            self._operands = tokens[3:] if len(tokens) > 2 else []
+        else:
+            self._result = None
+            self._opcode = tokens[0]
+            self._operands = tokens[1:] if len(tokens) > 1 else []
+
+    def __str__(self):
+        if self._result is None:
+            return f"      {self._opcode} {self._operands}"
+        return f"{self._result:3} = {self._opcode} {self._operands}"
+
+    # Returns the instruction opcode.
+    def opcode(self) -> str:
+        return self._opcode
+
+    # Returns the instruction operands.
+    def operands(self) -> list[str]:
+        return self._operands
+
+    # Returns the instruction output register. Calling this function is
+    # only allowed if has_output_register() is true.
+    def output_register(self) -> str:
+        assert self.has_output_register()
+        return self._result
+
+    # Returns true if this instruction has an output register, false otherwise.
+    def has_output_register(self) -> bool:
+        return self._result is not None
+
+    # This function is used to initialize state related to this instruction
+    # before module execution begins. For example, global Input variables
+    # can use this to store the lane ID into the register.
+    def static_execution(self, lane):
+        pass
+
+    # This function is called every time this instruction is executed by a
+    # tangle. This function should not be overridden directly; instead, see
+    # _impl and _advance_ip.
+    def runtime_execution(self, module, lane):
+        self._impl(module, lane)
+        self._advance_ip(module, lane)
+
+    # This function needs to be overridden if your instruction can be executed.
+    # It implements the logic of the instruction.
+    # 'Static' instructions like OpConstant should not override this since
+    # they are not supposed to be executed at runtime.
+    def _impl(self, module, lane):
+        raise RuntimeError(f"Unimplemented instruction {self}")
+
+    # By default, IP is incremented to point to the next instruction.
+    # If the instruction modifies IP (like OpBranch), this must be overridden.
+    def _advance_ip(self, module, lane):
+        lane.set_ip(lane.ip() + 1)
+
+
+# Those are parsed, but never executed.
+class OpEntryPoint(Instruction):
+    pass
+
+
+class OpFunction(Instruction):
+    pass
+
+
+class OpFunctionEnd(Instruction):
+    pass
+
+
+class OpLabel(Instruction):
+    pass
+
+
+class OpVariable(Instruction):
+    pass
+
+
+class OpName(Instruction):
+    def name(self) -> str:
+        return self._operands[1][1:-1]
+
+    def decoratedRegister(self) -> str:
+        return self._operands[0]
+
+
+# The only decoration we act on is BuiltIn, used to initialize the values.
+class OpDecorate(Instruction):
+    def static_execution(self, lane):
+        if self._operands[1] == "LinkageAttributes":
+            return
+
+        assert (
+            self._operands[1] == "BuiltIn"
+            and self._operands[2] == "SubgroupLocalInvocationId"
+        )
+        lane.set_register(self._operands[0], lane.tid())
+
+
+# Constants
+class OpConstant(Instruction):
+    def static_execution(self, lane):
+        lane.set_register(self._result, int(self._operands[1]))
+
+
+class OpConstantTrue(OpConstant):
+    def static_execution(self, lane):
+        lane.set_register(self._result, True)
+
+
+class OpConstantFalse(OpConstant):
+    def static_execution(self, lane):
+        lane.set_register(self._result, False)
+
+
+class OpConstantComposite(OpConstant):
+    def static_execution(self, lane):
+        # Operand 0 is the result type; the remaining operands are the
+        # constituents. SPIR-V requires them to be declared earlier in the
+        # module, so their registers are already initialized here.
+        result = []
+        for op in self._operands[1:]:
+            result.append(lane.get_register(op))
+        lane.set_register(self._result, result)
+
+
+# Composite constants are only built statically: like the other
+# OpConstant* classes, OpConstantComposite has no runtime _impl and is
+# never executed. Composites built at runtime are handled by
+# OpCompositeConstruct.
+
+
+# Control flow instructions
+class OpFunctionCall(Instruction):
+    def _impl(self, module, lane):
+        pass
+
+    def _advance_ip(self, module, lane):
+        entry = module.get_function_entry(self._operands[1])
+        lane.do_call(entry, self._result)
+
+
+class OpReturn(Instruction):
+    def _impl(self, module, lane):
+        pass
+
+    def _advance_ip(self, module, lane):
+        lane.do_return(None)
+
+
+class OpReturnValue(Instruction):
+    def _impl(self, module, lane):
+        pass
+
+    def _advance_ip(self, module, lane):
+        lane.do_return(lane.get_register(self._operands[0]))
+
+
+class OpBranch(Instruction):
+    def _impl(self, module, lane):
+        pass
+
+    def _advance_ip(self, module, lane):
+        lane.set_ip(module.get_bb_entry(self._operands[0]))
+        pass
+
+
+class OpBranchConditional(Instruction):
+    def _impl(self, module, lane):
+        pass
+
+    def _advance_ip(self, module, lane):
+        condition = lane.get_register(self._operands[0])
+        if condition:
+            lane.set_ip(module.get_bb_entry(self._operands[1]))
+        else:
+            lane.set_ip(module.get_bb_entry(self._operands[2]))
+
+
+class OpSwitch(Instruction):
+    def _impl(self, module, lane):
+        pass
+
+    def _advance_ip(self, module, lane):
+        value = lane.get_register(self._operands[0])
+        default_label = self._operands[1]
+        i = 2
+        while i < len(self._operands):
+            imm = int(self._operands[i])
+            label = self._operands[i + 1]
+            if value == imm:
+                lane.set_ip(module.get_bb_entry(label))
+                return
+            i += 2
+        lane.set_ip(module.get_bb_entry(default_label))
+
+
+class OpUnreachable(Instruction):
+    def _impl(self, module, lane):
+        raise RuntimeError("This instruction should never be executed.")
+
+
+# Convergence instructions
+class _MergeInstruction(Instruction):
+    def merge_location(self):
+        return self._operands[0]
+
+    def continue_location(self):
+        return None if len(self._operands) < 3 else self._operands[1]
+
+    def _impl(self, module, lane):
+        lane.handle_convergence_header(self)
+
+
+class OpLoopMerge(_MergeInstruction):
+    pass
+
+
+class OpSelectionMerge(_MergeInstruction):
+    pass
+
+
+# Other instructions
+class OpBitcast(Instruction):
+    def _impl(self, module, lane):
+        # TODO: find out the type from the defining instruction.
+        # This can only work for DXC.
+        if self._operands[0] == "%int":
+            lane.set_register(self._result, int(lane.get_register(self._operands[1])))
+        else:
+            raise RuntimeError("Unsupported OpBitcast operand")
+
+
+class OpAccessChain(Instruction):
+    def _impl(self, module, lane):
+        # Python's dynamic typing lets us simplify this. As long as the SPIR-V
+        # is legal, this should be fine.
+        # Note: SPIR-V structs are stored as tuples
+        value = lane.get_register(self._operands[1])
+        for operand in self._operands[2:]:
+            value = value[lane.get_register(operand)]
+        lane.set_register(self._result, value)
+
+
+class OpCompositeConstruct(Instruction):
+    def _impl(self, module, lane):
+        output = []
+        for op in self._operands[1:]:
+            output.append(lane.get_register(op))
+        lane.set_register(self._result, output)
+
+
+class OpStore(Instruction):
+    def _impl(self, module, lane):
+        lane.set_register(self._operands[0], lane.get_register(self._operands[1]))
+
+
+class OpLoad(Instruction):
+    def _impl(self, module, lane):
+        lane.set_register(self._result, lane.get_register(self._operands[1]))
+
+
+class OpIAdd(Instruction):
+    def _impl(self, module, lane):
+        LHS = lane.get_register(self._operands[1])
+        RHS = lane.get_register(self._operands[2])
+        lane.set_register(self._result, LHS + RHS)
+
+
+class OpISub(Instruction):
+    def _impl(self, module, lane):
+        LHS = lane.get_register(self._operands[1])
+        RHS = lane.get_register(self._operands[2])
+        lane.set_register(self._result, LHS - RHS)
+
+
+class OpIMul(Instruction):
+    def _impl(self, module, lane):
+        LHS = lane.get_register(self._operands[1])
+        RHS = lane.get_register(self._operands[2])
+        lane.set_register(self._result, LHS * RHS)
+
+
+class OpLogicalNot(Instruction):
+    def _impl(self, module, lane):
+        LHS = lane.get_register(self._operands[1])
+        lane.set_register(self._result, not LHS)
+
+
+class OpSGreaterThan(Instruction):
+    def _impl(self, module, lane):
+        LHS = lane.get_register(self._operands[1])
+        RHS = lane.get_register(self._operands[2])
+        lane.set_register(self._result, LHS > RHS)
+
+
+class _LessThan(Instruction):
+    def _impl(self, module, lane):
+        LHS = lane.get_register(self._operands[1])
+        RHS = lane.get_register(self._operands[2])
+        lane.set_register(self._result, LHS < RHS)
+
+
+class _GreaterThan(Instruction):
+    def _impl(self, module, lane):
+        LHS = lane.get_register(self._operands[1])
+        RHS = lane.get_register(self._operands[2])
+        lane.set_register(self._result, LHS > RHS)
+
+
+class OpSLessThan(_LessThan):
+    pass
+
+
+class OpULessThan(_LessThan):
+    pass
+
+
+class OpSGreaterThan(_GreaterThan):
+    pass
+
+
+class OpUGreaterThan(_GreaterThan):
+    pass
+
+
+class OpIEqual(Instruction):
+    def _impl(self, module, lane):
+        LHS = lane.get_register(self._operands[1])
+        RHS = lane.get_register(self._operands[2])
+        lane.set_register(self._result, LHS == RHS)
+
+
+class OpINotEqual(Instruction):
+    def _impl(self, module, lane):
+        LHS = lane.get_register(self._operands[1])
+        RHS = lane.get_register(self._operands[2])
+        lane.set_register(self._result, LHS != RHS)
+
+
+class OpPhi(Instruction):
+    def _impl(self, module, lane):
+        previousBBName = lane.get_previous_bb_name()
+        i = 1
+        while i < len(self._operands):
+            label = self._operands[i + 1]
+            if label == previousBBName:
+                lane.set_register(self._result, lane.get_register(self._operands[i]))
+                return
+            i += 2
+        raise RuntimeError("previousBB not in the OpPhi _operands")
+
+
+class OpSelect(Instruction):
+    def _impl(self, module, lane):
+        condition = lane.get_register(self._operands[1])
+        value = lane.get_register(self._operands[2 if condition else 3])
+        lane.set_register(self._result, value)
+
+
+# Wave intrinsics
+class OpGroupNonUniformBroadcastFirst(Instruction):
+    def _impl(self, module, lane):
+        assert lane.get_register(self._operands[1]) == 3
+        if lane.is_first_active_lane():
+            lane.broadcast_register(self._result, lane.get_register(self._operands[2]))
+
+
+class OpGroupNonUniformElect(Instruction):
+    def _impl(self, module, lane):
+        lane.set_register(self._result, lane.is_first_active_lane())
diff --git a/llvm/utils/spirv-sim/spirv-sim.py b/llvm/utils/spirv-sim/spirv-sim.py
new file mode 100755
index 00000000000000..e5b01a2445cf08
--- /dev/null
+++ b/llvm/utils/spirv-sim/spirv-sim.py
@@ -0,0 +1,627 @@
+#!/usr/bin/env python3
+
+import fileinput
+import inspect
+from typing import Any
+from dataclasses import dataclass
+import sys
+from instructions import *
+import argparse
+import re
+
+RE_EXPECTS = re.compile(r"^([0-9]+,)*[0-9]+$")
+
+# Parse the SPIR-V instructions. Some instructions are ignored because
+# not required to simulate this module.
+# Instructions are to be implemented in instructions.py
+def parseInstruction(i):
+    IGNORED = set(
+        [
+            "OpCapability",
+            "OpMemoryModel",
+            "OpExecutionMode",
+            "OpExtension",
+            "OpSource",
+            "OpTypeInt",
+            "OpTypeFloat",
+            "OpTypeBool",
+            "OpTypeVoid",
+            "OpTypeFunction",
+            "OpTypePointer",
+            "OpTypeArray",
+        ]
+    )
+    if i.opcode() in IGNORED:
+        return None
+
+    try:
+        Type = getattr(sys.modules["instructions"], i.opcode())
+    except AttributeError:
+        raise RuntimeError(f"Unsupported instruction {i}")
+    if not inspect.isclass(Type):
+        raise RuntimeError(
+            f"{i} instruction definition is not a class. Did you use 'def' instead of 'class'?"
+        )
+    return Type(i.line)
+
+# Split a list of instructions into pieces. Pieces are delimited by instructions of the type splitType.
+# The delimiter is the first instruction of the next piece.
+# This function returns no empty pieces:
+# - 2 consecutive delimiters produce 2 pieces: one holding only the first delimiter, and one starting
+#   with the second delimiter followed by its instructions.
+# - if the first instruction is a delimiter, the first piece will begin with this delimiter.
+def splitInstructions(
+    splitType: type, instructions: list[Instruction]
+) -> list[list[Instruction]]:
+    blocks = [[]]
+    for instruction in instructions:
+        if type(instruction) is splitType and len(blocks[-1]) > 0:
+            blocks.append([])
+        blocks[-1].append(instruction)
+    return blocks
+
+# Defines a BasicBlock in the simulator.
+# Begins at an OpLabel, and ends with a control-flow instruction.
+class BasicBlock:
+    def __init__(self, instructions):
+        assert type(instructions[0]) is OpLabel
+        # The name of the basic block, which is the register of the leading
+        # OpLabel.
+        self._name = instructions[0].output_register()
+        # The list of instructions belonging to this block.
+        self._instructions = instructions[1:]
+
+    # Returns the name of this basic block.
+    def name(self):
+        return self._name
+
+    # Returns the instruction at index in this basic block.
+    def __getitem__(self, index: int) -> Instruction:
+        return self._instructions[index]
+
+    # Returns the number of instructions in this basic block, excluding the
+    # leading OpLabel.
+    def __len__(self):
+        return len(self._instructions)
+
+    def dump(self):
+        print(f"        {self._name}:")
+        for instruction in self._instructions:
+            print(f"        {instruction}")
+
+
+# Defines a Function in the simulator.
+class Function:
+    def __init__(self, instructions):
+        assert type(instructions[0]) is OpFunction
+        # The name of the function (name of the register returned by OpFunction).
+        self._name: str = instructions[0].output_register()
+        # The list of basic blocks that belong to this function.
+        self._basic_blocks: list[BasicBlock] = []
+        # The variables local to this function.
+        self._variables: list[OpVariable] = [
+            x for x in instructions if type(x) is OpVariable
+        ]
+
+        assert type(instructions[0]) is OpFunction
+        assert type(instructions[-1]) is OpFunctionEnd
+        body = filter(lambda x: type(x) != OpVariable, instructions[1:-1])
+        for block in splitInstructions(OpLabel, body):
+            self._basic_blocks.append(BasicBlock(block))
+
+    # Returns the name of this function.
+    def name(self) -> str:
+        return self._name
+
+    # Returns the basic block at index in this function.
+    def __getitem__(self, index: int) -> BasicBlock:
+        return self._basic_blocks[index]
+
+    # Returns the index of the basic block with the given name if found,
+    # -1 otherwise.
+    def get_bb_index(self, name) -> int:
+        for i in range(len(self._basic_blocks)):
+            if self._basic_blocks[i].name() == name:
+                return i
+        return -1
+
+    def dump(self):
+        print("      Variables:")
+        for var in self._variables:
+            print(f"        {var}")
+        print("      Blocks:")
+        for bb in self._basic_blocks:
+            bb.dump()
+
+
+# Represents an instruction pointer in the simulator.
+@dataclass
+class InstructionPointer:
+    # The current function the IP points to.
+    function: Function
+    # The basic block index in function IP points to.
+    basic_block: int
+    # The instruction in basic_block IP points to.
+    instruction_index: int
+
+    def __str__(self):
+        bb = self.function[self.basic_block]
+        i = bb[self.instruction_index]
+        return f"{bb.name()}:{self.instruction_index} in {self.function.name()} | {i}"
+
+    def __hash__(self):
+        return hash((self.function.name(), self.basic_block, self.instruction_index))
+
+    # Returns the basic block IP points to.
+    def bb(self) -> BasicBlock:
+        return self.function[self.basic_block]
+
+    # Returns the instruction IP points to.
+    def instruction(self):
+        return self.function[self.basic_block][self.instruction_index]
+
+    # Advance IP by "value" instructions. This only works within a basic block:
+    # moving IP past the end of its basic block fails the assertion below.
+    def __add__(self, value: int):
+        bb = self.function[self.basic_block]
+        assert len(bb) > self.instruction_index + value
+        return InstructionPointer(
+            self.function, self.basic_block, self.instruction_index + value
+        )
+
+
+# Defines a Lane in this simulator.
+class Lane:
+    def __init__(self, wave, tid):
+        # The registers known by this lane.
+        self._registers: dict[str, Any] = {}
+        # The current IP of this lane.
+        self._ip: InstructionPointer = None
+        # Whether this lane is running.
+        self._running: bool = True
+        # The wave this lane belongs to.
+        self._wave: Wave = wave
+        # The callstack of this lane. Each tuple represents 1 call.
+        #   The first element is the IP the function will return to.
+        #   The second element is the callback to call to store the return value
+        #   into the correct register.
+        self._callstack: list[tuple[InstructionPointer, Any]] = []
+
+        # The index of this lane in the wave.
+        self._tid = tid
+        # The last BB this lane was executing in.
+        self._previous_bb = None
+        # The current BB this lane is executing in.
+        self._current_bb = None
+
+    # Returns the lane/thread ID of this lane in its wave.
+    def tid(self):
+        return self._tid
+
+    # Returns true if this lane is the first active lane (by index) in the current tangle.
+    def is_first_active_lane(self):
+        return self._tid == self._wave.get_first_active_lane_index()
+
+    # Broadcast value into the registers of all active lanes.
+    def broadcast_register(self, register, value):
+        self._wave.broadcast_register(register, value)
+
+    # Returns the IP this lane is currently at.
+    def ip(self):
+        return self._ip
+
+    # Returns true if this lane is running, false otherwise.
+    # Running means not dead. An inactive lane is running.
+    def running(self):
+        return self._running
+
+    # Set the register at "name" to "value" in this lane.
+    def set_register(self, name, value):
+        self._registers[name] = value
+
+    # Get the value in register "name" in this lane.
+    # If allow_undef is true, fetching an unknown register returns None instead of failing.
+    def get_register(self, name, allow_undef=False):
+        if allow_undef and name not in self._registers:
+            return None
+        return self._registers[name]
+
+    def set_ip(self, ip):
+        if ip.bb() != self._current_bb:
+            self._previous_bb = self._current_bb
+            self._current_bb = ip.bb()
+        self._ip = ip
+
+    def get_previous_bb_name(self):
+        return self._previous_bb.name()
+
+    def handle_convergence_header(self, instruction):
+        self._wave.handle_convergence_header(self, instruction)
+
+    def do_call(self, ip, output_register):
+        return_ip = None if self._ip is None else self._ip + 1
+        self._callstack.append(
+            (return_ip, lambda value: self.set_register(output_register, value))
+        )
+        self.set_ip(ip)
+
+    def do_return(self, value):
+        ip, callback = self._callstack[-1]
+        self._callstack.pop()
+
+        callback(value)
+        if len(self._callstack) == 0:
+            self._running = False
+        else:
+            self.set_ip(ip)
+
+
+# Represents the SPIR-V module in the simulator.
+class Module:
+    def __init__(self, instructions):
+        chunks = splitInstructions(OpFunction, instructions)
+
+        # The instructions located outside of all functions.
+        self._prolog = chunks[0]
+        # The functions in this module.
+        self._functions = {}
+        # Global variables in this module.
+        self._globals = [
+            x
+            for x in instructions
+            if type(x) is OpVariable or issubclass(type(x), OpConstant)
+        ]
+
+        # Helper dictionaries to get real names of registers, or registers by names.
+        self._name2reg = {}
+        self._reg2name = {}
+        for instruction in instructions:
+            if type(instruction) is OpName:
+                name = instruction.name()
+                reg = instruction.decoratedRegister()
+                self._name2reg[name] = reg
+                self._reg2name[reg] = name
+
+        for chunk in chunks[1:]:
+            function = Function(chunk)
+            assert function.name() not in self._functions
+            self._functions[function.name()] = function
+
+    # Returns the register matching "name" if any, None otherwise.
+    # This assumes names are unique.
+    def getRegisterFromName(self, name):
+        if name in self._name2reg:
+            return self._name2reg[name]
+        return None
+
+    # Returns the name given to "register" if any, None otherwise.
+    def getNameFromRegister(self, register):
+        if register in self._reg2name:
+            return self._reg2name[register]
+        return None
+
+    # Initialize the module before wave execution begins.
+    # See Instruction::static_execution for more details.
+    def initialize(self, lane):
+        for instruction in self._globals:
+            instruction.static_execution(lane)
+
+        # Initialize builtins
+        for instruction in self._prolog:
+            if type(instruction) is OpDecorate:
+                instruction.static_execution(lane)
+
+    def execute_one_instruction(self, lane: Lane, ip: InstructionPointer) -> None:
+        ip.instruction().runtime_execution(self, lane)
+
+    # Returns the first valid IP for the function defined by the given register.
+    # Calling this with a register not returned by OpFunction is illegal.
+    def get_function_entry(self, register: str) -> InstructionPointer:
+        if register not in self._functions:
+            raise RuntimeError(f"Function defining {register} not found.")
+        return InstructionPointer(self._functions[register], 0, 0)
+
+    # Returns the first valid IP for the basic block defined by register.
+    # Calling this with a register not returned by an OpLabel is illegal.
+    def get_bb_entry(self, register: str) -> InstructionPointer:
+        for name, function in self._functions.items():
+            index = function.get_bb_index(register)
+            if index != -1:
+                return InstructionPointer(function, index, 0)
+        raise RuntimeError(f"Instruction defining {register} not found.")
+
+    # Returns the list of function names in this module.
+    # If an OpName exists for this function, returns the pretty name, else
+    # returns the register name.
+    def get_function_names(self):
+        return [self.getNameFromRegister(reg) for reg, func in self._functions.items()]
+
+    # Returns the output registers of the global variables and constants.
+    def variables(self) -> list[str]:
+        return [x.output_register() for x in self._globals]
+
+    def dump(self, function_name: str = None):
+        print("Module:")
+        print("  globals:")
+        for instruction in self._globals:
+            print(f"    {instruction}")
+
+        if function_name is None:
+            print("  functions:")
+            for register, function in self._functions.items():
+                name = self.getNameFromRegister(register)
+                print(f"  Function {register} ({name})")
+                function.dump()
+            return
+
+        register = self.getRegisterFromName(function_name)
+        print(f"  function {register} ({function_name}):")
+        if register is not None:
+            self._functions[register].dump()
+        else:
+            print(f"    error: cannot find function.")
+
+
+# Defines a convergence requirement for the simulation:
+# A list of lanes impacted by a merge and possibly the associated
+# continue target.
+@dataclass
+class ConvergenceRequirement:
+    mergeTarget: InstructionPointer
+    continueTarget: InstructionPointer
+    impactedLanes: set[int]
+
+
+# Defines a Lane group/Wave in the simulator.
+class Wave:
+    def __init__(self, module, wave_size: int):
+        assert wave_size > 0
+        # The module this wave will execute.
+        self._module = module
+        # The lanes this wave will be composed of.
+        self._lanes = []
+        for i in range(wave_size):
+            self._lanes.append(Lane(self, i))
+
+        # The instructions scheduled for execution.
+        self._tasks: dict[InstructionPointer, list[Lane]] = {}
+        # The actual requirements to comply with when executing instructions.
+        # e.g: the set of lanes required to merge before executing the merge block.
+        self._convergence_requirements = []
+        # The indices of the active lanes for the current executing instruction.
+        self._active_lane_indices = set()
+
+    # Returns True if the given IP can be executed for the given list of lanes.
+    def _is_task_candidate(self, ip: InstructionPointer, lanes: list[Lane]):
+        merged_lanes = set()
+        for lane in self._lanes:
+            if not lane.running():
+                merged_lanes.add(lane.tid())
+
+        for requirement in self._convergence_requirements:
+            # This task is not executing this requirement's merge or continue
+            # target: lanes already waiting at those targets can be ignored.
+            if requirement.mergeTarget != ip and requirement.continueTarget != ip:
+                for tid in requirement.impactedLanes:
+                    if self._lanes[tid].ip() == requirement.mergeTarget:
+                        merged_lanes.add(tid)
+                    if self._lanes[tid].ip() == requirement.continueTarget:
+                        merged_lanes.add(tid)
+                continue
+
+            # This task is executing the current requirement continue/merge
+            # target.
+            for tid in requirement.impactedLanes:
+                lane = self._lanes[tid]
+                if not lane.running():
+                    continue
+
+                if lane.tid() in merged_lanes:
+                    continue
+
+                if ip == requirement.mergeTarget:
+                    if lane.ip() != requirement.mergeTarget:
+                        return False
+                else:
+                    if (
+                        lane.ip() != requirement.mergeTarget
+                        and lane.ip() != requirement.continueTarget
+                    ):
+                        return False
+        return True
+
+    # Returns the next task we can schedule. This must always return a task.
+    # Calling this when all lanes are dead is invalid.
+    def _get_next_runnable_task(self):
+        candidate = None
+        for ip, lanes in self._tasks.items():
+            if len(lanes) == 0:
+                continue
+            if self._is_task_candidate(ip, lanes):
+                candidate = ip
+                break
+
+        if candidate is not None:
+            lanes = self._tasks[candidate]
+            del self._tasks[candidate]
+            return (candidate, lanes)
+        raise RuntimeError("No task to execute. Deadlock?")
+
+    # Handle an encountered merge instruction for the given lane.
+    def handle_convergence_header(self, lane: Lane, instruction: Instruction):
+        mergeTarget = self._module.get_bb_entry(instruction.merge_location())
+        for requirement in self._convergence_requirements:
+            if requirement.mergeTarget == mergeTarget:
+                requirement.impactedLanes.add(lane.tid())
+                return
+
+        continueTarget = None
+        if instruction.continue_location():
+            continueTarget = self._module.get_bb_entry(instruction.continue_location())
+        requirement = ConvergenceRequirement(
+            mergeTarget, continueTarget, set([lane.tid()])
+        )
+        self._convergence_requirements.append(requirement)
+
+    # Returns true if some instructions are scheduled for execution.
+    def _has_tasks(self):
+        return len(self._tasks) > 0
+
+    # Returns the index of the first active lane right now.
+    def get_first_active_lane_index(self) -> int:
+        return min(self._active_lane_indices)
+
+    # Broadcast the given value into the registers of all active lanes.
+    def broadcast_register(self, register, value) -> None:
+        for tid in self._active_lane_indices:
+            self._lanes[tid].set_register(register, value)
+
+    # Returns the entry IP of the function associated with 'name'.
+    # Calling this function with an invalid name is illegal.
+    def _get_function_from_name(self, name: str) -> InstructionPointer:
+        register = self._module.getRegisterFromName(name)
+        assert register is not None
+        return self._module.get_function_entry(register)
+
+    # Run the wave on the function 'function_name' until all lanes are dead.
+    # If verbose is True, execution trace is printed.
+    # Returns the value returned by the function for each lane.
+    def run(self, function_name: str, verbose: bool = False) -> list[int]:
+        for t in self._lanes:
+            self._module.initialize(t)
+
+        function = self._get_function_from_name(function_name)
+        assert function is not None
+        for t in self._lanes:
+            t.do_call(function, "__shader_output__")
+
+        self._tasks[self._lanes[0].ip()] = list(self._lanes)
+        while self._has_tasks():
+            ip, lanes = self._get_next_runnable_task()
+            self._active_lane_indices = set([x.tid() for x in lanes])
+            if verbose:
+                print(
+                    f"Executing with lanes {self._active_lane_indices}: {ip.instruction()}"
+                )
+
+            for lane in lanes:
+                self._module.execute_one_instruction(lane, ip)
+                if not lane.running():
+                    continue
+
+                if lane.ip() in self._tasks:
+                    self._tasks[lane.ip()].append(lane)
+                else:
+                    self._tasks[lane.ip()] = [lane]
+
+            if verbose and ip.instruction().has_output_register():
+                register = ip.instruction().output_register()
+                print(
+                    f"   {register:3} = {[ x.get_register(register, allow_undef=True) for x in lanes ]}"
+                )
+
+        output = []
+        for lane in self._lanes:
+            output.append(lane.get_register("__shader_output__"))
+        return output
+
+    def dump_register(self, register):
+        for lane in self._lanes:
+            print(
+                f" Lane {lane.tid():2} | {register:3} = {lane.get_register(register)}"
+            )
+
+
+parser = argparse.ArgumentParser(
+    description="simulator", formatter_class=argparse.ArgumentDefaultsHelpFormatter
+)
+parser.add_argument(
+    "-i", "--input", help="Text SPIR-V to read from", required=False, default="-"
+)
+parser.add_argument("-f", "--function", help="Function to execute")
+parser.add_argument("-w", "--wave", help="Wave size", default=32, required=False)
+parser.add_argument(
+    "-e",
+    "--expects",
+    help="Expected results per lane, as a comma-separated list of values without spaces. Ex: '1,2,3'.",
+)
+parser.add_argument("-v", "--verbose", help="verbose", action="store_true")
+args = parser.parse_args()
+
+
+def load_instructions(filename):
+    if filename is None:
+        return []
+
+    if filename.strip() != "-":
+        try:
+            with open(filename, "r") as f:
+                lines = f.read().split("\n")
+        except OSError:
+            return []
+    else:
+        lines = sys.stdin.readlines()
+
+    # Remove leading/trailing whitespace.
+    lines = [x.strip() for x in lines]
+    # Drop empty lines and comments.
+    lines = [x for x in lines if len(x) != 0 and x[0] != ";"]
+
+    instructions = []
+    for i in [Instruction(x) for x in lines]:
+        out = parseInstruction(i)
+        if out is not None:
+            instructions.append(out)
+    return instructions
+
+
+def main():
+    if args.expects is None or not RE_EXPECTS.match(args.expects):
+        print("Missing or invalid --expects/-e flag.", file=sys.stderr)
+        sys.exit(1)
+    if args.function is None:
+        print("Missing --function/-f flag.", file=sys.stderr)
+        sys.exit(1)
+    try:
+        int(args.wave)
+    except ValueError:
+        print("Invalid format for --wave/-w flag.", file=sys.stderr)
+        sys.exit(1)
+
+    expected_results = [int(x.rstrip().lstrip()) for x in args.expects.split(",")]
+    wave_size = int(args.wave)
+    if len(expected_results) != wave_size:
+        print("Wave size != expected result array size", file=sys.stderr)
+        sys.exit(1)
+
+    instructions = load_instructions(args.input)
+    if len(instructions) == 0:
+        print("Invalid input. Expected a text SPIR-V module.", file=sys.stderr)
+        sys.exit(1)
+
+    module = Module(instructions)
+    if args.verbose:
+        module.dump()
+        module.dump(args.function)
+
+    function_names = module.get_function_names()
+    if args.function not in function_names:
+        print(
+            f"'{args.function}' function not found. Known functions are:",
+            file=sys.stderr,
+        )
+        for name in function_names:
+            print(f" - {name}", file=sys.stderr)
+        sys.exit(1)
+
+    wave = Wave(module, wave_size)
+    results = wave.run(args.function, verbose=args.verbose)
+
+    if expected_results != results:
+        print("Expected != Observed", file=sys.stderr)
+        print(f"{expected_results} != {results}", file=sys.stderr)
+        sys.exit(1)
+    sys.exit(0)
+
+
+main()
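
The splitting rule used by `splitInstructions` is easier to see on plain data. A minimal sketch, assuming strings stand in for `Instruction` objects and a predicate stands in for the `type(instruction) is splitType` check. `split_pieces` and `is_delimiter` are hypothetical names, not part of spirv-sim.py:

```python
from typing import Callable


def split_pieces(items: list[str], is_delimiter: Callable[[str], bool]) -> list[list[str]]:
    # Same rule as splitInstructions: a delimiter starts a new piece,
    # unless the current piece is still empty.
    pieces: list[list[str]] = [[]]
    for item in items:
        if is_delimiter(item) and len(pieces[-1]) > 0:
            pieces.append([])
        pieces[-1].append(item)
    return pieces


body = ["OpLabel %1", "OpBranch %2", "OpLabel %2", "OpLabel %3", "OpReturn"]
print(split_pieces(body, lambda s: s.startswith("OpLabel")))
# [['OpLabel %1', 'OpBranch %2'], ['OpLabel %2'], ['OpLabel %3', 'OpReturn']]
```

Two consecutive delimiters (`%2`, `%3`) yield a piece holding only the first one. An empty input still yields a single empty piece, `[[]]`.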

>From 4c474b7770ff90766e5cc0d4b719f8fc6523996a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Wed, 14 Aug 2024 16:50:34 +0200
Subject: [PATCH 2/5] format and RUN line ordering
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Nathan Gauër <brioche at google.com>
---
 llvm/test/Other/spirv-sim/branch.spv | 2 +-
 llvm/test/Other/spirv-sim/loop.spv   | 2 +-
 llvm/utils/spirv-sim/spirv-sim.py    | 3 +++
 3 files changed, 5 insertions(+), 2 deletions(-)
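
main() checks the `--expects` values in these RUN lines against `RE_EXPECTS` before splitting them on commas. Each value must therefore be a non-negative integer, and the list must not contain spaces. A minimal sketch of that check, reusing the regular expression from spirv-sim.py. `parse_expects` is a hypothetical helper, not a function in the script:

```python
import re

# Same pattern as RE_EXPECTS in spirv-sim.py.
RE_EXPECTS = re.compile(r"^([0-9]+,)*[0-9]+$")


def parse_expects(value: str) -> list[int]:
    # Reject what the simulator's main() would reject, then convert.
    if not RE_EXPECTS.match(value):
        raise ValueError(f"invalid --expects value: {value!r}")
    return [int(x) for x in value.split(",")]


print(parse_expects("5,6,6"))    # [5, 6, 6]
print(parse_expects("0,2,2,4"))  # [0, 2, 2, 4]
```

Inputs such as `1, 2, 3` or `-1` are rejected, so negative expected values cannot currently be expressed.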

diff --git a/llvm/test/Other/spirv-sim/branch.spv b/llvm/test/Other/spirv-sim/branch.spv
index 7ce0e7da3f058b..7ee0ebcad249dd 100644
--- a/llvm/test/Other/spirv-sim/branch.spv
+++ b/llvm/test/Other/spirv-sim/branch.spv
@@ -1,5 +1,5 @@
-; RUN: spirv-sim --function=simple --wave=3 --expects=5,6,6 -i %s
 ; RUN: %if spirv-tools %{ spirv-as %s -o - | spirv-val - %}
+; RUN: spirv-sim --function=simple --wave=3 --expects=5,6,6 -i %s
                OpCapability Shader
                OpCapability GroupNonUniform
                OpMemoryModel Logical GLSL450
diff --git a/llvm/test/Other/spirv-sim/loop.spv b/llvm/test/Other/spirv-sim/loop.spv
index c753ea2c149410..4fd0f1a7c96a31 100644
--- a/llvm/test/Other/spirv-sim/loop.spv
+++ b/llvm/test/Other/spirv-sim/loop.spv
@@ -1,5 +1,5 @@
-; RUN: spirv-sim --function=simple --wave=4 --expects=0,2,2,4 -i %s
 ; RUN: %if spirv-tools %{ spirv-as %s -o - | spirv-val - %}
+; RUN: spirv-sim --function=simple --wave=4 --expects=0,2,2,4 -i %s
                OpCapability Shader
                OpCapability GroupNonUniform
                OpMemoryModel Logical GLSL450
diff --git a/llvm/utils/spirv-sim/spirv-sim.py b/llvm/utils/spirv-sim/spirv-sim.py
index e5b01a2445cf08..f729e117366b4b 100755
--- a/llvm/utils/spirv-sim/spirv-sim.py
+++ b/llvm/utils/spirv-sim/spirv-sim.py
@@ -11,6 +11,7 @@
 
 RE_EXPECTS = re.compile(r"^([0-9]+,)*[0-9]+$")
 
+
 # Parse the SPIR-V instructions. Some instructions are ignored because
 # not required to simulate this module.
 # Instructions are to be implemented in instructions.py
@@ -44,6 +45,7 @@ def parseInstruction(i):
         )
     return Type(i.line)
 
+
 # Split a list of instructions into pieces. Pieces are delimited by instructions of the type splitType.
 # The delimiter is the first instruction of the next piece.
 # This function returns no empty pieces:
@@ -60,6 +62,7 @@ def splitInstructions(
         blocks[-1].append(instruction)
     return blocks
 
+
 # Defines a BasicBlock in the simulator.
 # Begins at an OpLabel, and ends with a control-flow instruction.
 class BasicBlock:

>From 538759d93134aebc436a519d1b5c4f0283a1105d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Wed, 14 Aug 2024 18:42:44 +0200
Subject: [PATCH 3/5] fix python path on windows
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Nathan Gauër <brioche at google.com>
---
 llvm/test/Other/spirv-sim/lit.local.cfg | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)
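
With this change, the `spirv-sim` substitution expands to the interpreter lit runs under, followed by the script path. It no longer relies on the script's shebang line, which is not reliable on Windows. A sketch of the resulting string, assuming hypothetical paths in place of `config.python_executable` and `config.llvm_src_root`. `spirv_sim_substitution` is a hypothetical helper:

```python
import os


def spirv_sim_substitution(python_executable: str, llvm_src_root: str) -> tuple[str, str]:
    # Same construction as lit.local.cfg: the interpreter path is quoted,
    # the script path is appended unquoted.
    spirv_sim_root = os.path.join(llvm_src_root, "utils", "spirv-sim")
    return (
        "spirv-sim",
        "'%s' %s" % (python_executable, os.path.join(spirv_sim_root, "spirv-sim.py")),
    )


print(spirv_sim_substitution("/usr/bin/python3", "/src/llvm"))
```

Only the interpreter path is quoted. A source tree path containing spaces would likely need quoting as well.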

diff --git a/llvm/test/Other/spirv-sim/lit.local.cfg b/llvm/test/Other/spirv-sim/lit.local.cfg
index d343a8f2bd9b09..67a8d9196f588a 100644
--- a/llvm/test/Other/spirv-sim/lit.local.cfg
+++ b/llvm/test/Other/spirv-sim/lit.local.cfg
@@ -1,2 +1,8 @@
 spirv_sim_root = os.path.join(config.llvm_src_root, "utils", "spirv-sim")
-config.substitutions.append(("spirv-sim", os.path.join(spirv_sim_root, "spirv-sim.py")))
+config.substitutions.append(
+  (
+    "spirv-sim",
+    "'%s' %s"
+    % (config.python_executable, os.path.join(spirv_sim_root, "spirv-sim.py")),
+  )
+)

>From 2842d11ade9c9da4b2a76956d43161c1fd695c67 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Thu, 22 Aug 2024 17:21:42 +0200
Subject: [PATCH 4/5] pr-feedback: annotation fix & cleanups
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Nathan Gauër <brioche at google.com>
---
 llvm/test/Other/spirv-sim/constant.spv |  36 ++++++
 llvm/utils/spirv-sim/instructions.py   |  49 +++-----
 llvm/utils/spirv-sim/spirv-sim.py      | 167 ++++++++++++++-----------
 3 files changed, 151 insertions(+), 101 deletions(-)
 create mode 100644 llvm/test/Other/spirv-sim/constant.spv
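
The new constant.spv test assumes that `OpConstantComposite` values are stored as nested Python lists, and that `OpCompositeExtract` descends one nesting level per literal index. The `OpCompositeExtract` implementation is not visible in this excerpt, so the following is only a sketch under that assumption. It uses plain lists instead of lane registers, and `composite_extract` is a hypothetical helper:

```python
def composite_extract(value, indices):
    # Each literal index selects one member, descending one nesting level.
    for index in indices:
        value = value[index]
    return value


int_1, int_2 = 1, 2
s1_1_2 = [int_1, int_2, int_1]  # %s1_1_2 = OpConstantComposite %s1 %int_1 %int_2 %int_1
s2_s1 = [s1_1_2]                # %s2_s1  = OpConstantComposite %s2 %s1_1_2

print(composite_extract(s1_1_2, [1]))    # 2, matching --expects=2 for function a
print(composite_extract(s2_s1, [0, 2]))  # 1, matching --expects=1 for function b
```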

diff --git a/llvm/test/Other/spirv-sim/constant.spv b/llvm/test/Other/spirv-sim/constant.spv
new file mode 100644
index 00000000000000..1002427943a8d2
--- /dev/null
+++ b/llvm/test/Other/spirv-sim/constant.spv
@@ -0,0 +1,36 @@
+; RUN: %if spirv-tools %{ spirv-as %s -o - | spirv-val - %}
+; RUN: spirv-sim --function=a --wave=1 --expects=2 -i %s
+; RUN: spirv-sim --function=b --wave=1 --expects=1 -i %s
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %main "main"
+               OpExecutionMode %main LocalSize 1 1 1
+               OpSource HLSL 670
+               OpName %a "a"
+               OpName %b "b"
+               OpName %main "main"
+        %int = OpTypeInt 32 1
+         %s1 = OpTypeStruct %int %int %int
+         %s2 = OpTypeStruct %s1
+      %int_1 = OpConstant %int 1
+      %int_2 = OpConstant %int 2
+     %s1_1_2 = OpConstantComposite %s1 %int_1 %int_2 %int_1
+      %s2_s1 = OpConstantComposite %s2 %s1_1_2
+       %void = OpTypeVoid
+  %main_type = OpTypeFunction %void
+   %simple_type = OpTypeFunction %int
+       %main = OpFunction %void None %main_type
+      %entry = OpLabel
+               OpReturn
+               OpFunctionEnd
+          %a = OpFunction %int None %simple_type
+          %1 = OpLabel
+          %2 = OpCompositeExtract %int %s1_1_2 1
+               OpReturnValue %2
+               OpFunctionEnd
+          %b = OpFunction %int None %simple_type
+          %3 = OpLabel
+          %4 = OpCompositeExtract %int %s2_s1 0 2
+               OpReturnValue %4
+               OpFunctionEnd
+
diff --git a/llvm/utils/spirv-sim/instructions.py b/llvm/utils/spirv-sim/instructions.py
index 5f4ef831de0645..5ae0f826648ee4 100644
--- a/llvm/utils/spirv-sim/instructions.py
+++ b/llvm/utils/spirv-sim/instructions.py
@@ -1,16 +1,16 @@
+from typing import Optional
+
 # Base class for an instruction. To implement a basic instruction that doesn't
 # impact the control-flow, create a new class inheriting from this.
 class Instruction:
-    _result: str
+    # Contains the name of the output register, if any.
+    _result: Optional[str]
+    # Contains the instruction opcode.
+    _opcode: str
+    # Contains all the instruction operands, except result and opcode.
+    _operands: list[str]
 
     def __init__(self, line: str):
-        # Contains the name of the output register, if any.
-        self._result: str = None
-        # Contains the instruction opcode.
-        self._opcode: str = None
-        # Contains all the instruction operands, except result and opcode.
-        self._operands: list[str] = []
-
         self.line = line
         tokens = line.split()
         if len(tokens) > 1 and tokens[1] == "=":
@@ -38,7 +38,7 @@ def operands(self) -> list[str]:
     # Returns the instruction output register. Calling this function is
     # only allowed if has_output_register() is true.
     def output_register(self) -> str:
-        assert self.has_output_register()
+        assert self._result is not None
         return self._result
 
     # Returns true if this function has an output register. False otherwise.
@@ -132,20 +132,11 @@ def static_execution(self, lane):
 class OpConstantComposite(OpConstant):
     def static_execution(self, lane):
         result = []
-        length = self.get_register(self._operands[0])
         for op in self._operands[1:]:
-            result.append(self.get_register(op))
+            result.append(lane.get_register(op))
         lane.set_register(self._result, result)
 
 
-class OpConstantComposite(OpConstant):
-    def static_execution(self, vm, state):
-        output = []
-        for op in self._operands[1:]:
-            output.append(state.get_register(op))
-        state.set_register(self._result, output)
-
-
 # Control flow instructions
 class OpFunctionCall(Instruction):
     def _impl(self, module, lane):
@@ -217,7 +208,7 @@ def _impl(self, module, lane):
 
 
 # Convergence instructions
-class _MergeInstruction(Instruction):
+class MergeInstruction(Instruction):
     def merge_location(self):
         return self._operands[0]
 
@@ -228,11 +219,11 @@ def _impl(self, module, lane):
         lane.handle_convergence_header(self)
 
 
-class OpLoopMerge(_MergeInstruction):
+class OpLoopMerge(MergeInstruction):
     pass
 
 
-class OpSelectionMerge(_MergeInstruction):
+class OpSelectionMerge(MergeInstruction):
     pass
 
 
@@ -265,6 +256,13 @@ def _impl(self, module, lane):
             output.append(lane.get_register(op))
         lane.set_register(self._result, output)
 
+class OpCompositeExtract(Instruction):
+    def _impl(self, module, lane):
+        value = lane.get_register(self._operands[1])
+        output = value
+        for op in self._operands[2:]:
+            output = output[int(op)]
+        lane.set_register(self._result, output)
 
 class OpStore(Instruction):
     def _impl(self, module, lane):
@@ -303,13 +301,6 @@ def _impl(self, module, lane):
         lane.set_register(self._result, not LHS)
 
 
-class OpSGreaterThan(Instruction):
-    def _impl(self, module, lane):
-        LHS = lane.get_register(self._operands[1])
-        RHS = lane.get_register(self._operands[2])
-        lane.set_register(self._result, LHS > RHS)
-
-
 class _LessThan(Instruction):
     def _impl(self, module, lane):
         LHS = lane.get_register(self._operands[1])
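The operand indices used by `OpCompositeExtract._impl` (`_operands[1]` for the composite, `_operands[2:]` for the literal indices) follow from how a line is tokenized: the result type stays in `_operands[0]`. Below is a minimal, hypothetical sketch of that split. It assumes only the `tokens[1] == "="` check visible in `Instruction.__init__`; the real constructor continues past what this diff shows, and `parse_line` is not a function of the simulator.

```python
from typing import Optional


def parse_line(line: str) -> tuple[Optional[str], str, list[str]]:
    """Split one textual SPIR-V instruction into (result, opcode, operands).

    Hypothetical helper: it assumes the "%id = OpFoo a b" and "OpFoo a b"
    shapes that Instruction.__init__ distinguishes with tokens[1] == "=".
    """
    tokens = line.split()
    if len(tokens) > 1 and tokens[1] == "=":
        # Result-producing instruction: "%result = OpCode operands..."
        return tokens[0], tokens[2], tokens[3:]
    # No output register: "OpCode operands..."
    return None, tokens[0], tokens[1:]


print(parse_line("%4 = OpCompositeExtract %int %s2_s1 0 2"))
print(parse_line("OpReturnValue %4"))
```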
diff --git a/llvm/utils/spirv-sim/spirv-sim.py b/llvm/utils/spirv-sim/spirv-sim.py
index f729e117366b4b..dd2af79b258d91 100755
--- a/llvm/utils/spirv-sim/spirv-sim.py
+++ b/llvm/utils/spirv-sim/spirv-sim.py
@@ -1,13 +1,14 @@
 #!/usr/bin/env python3
 
-import fileinput
-import inspect
-from typing import Any
+from __future__ import annotations
 from dataclasses import dataclass
-import sys
 from instructions import *
+from typing import Any,Iterable,Callable,Optional,Tuple
 import argparse
+import fileinput
+import inspect
 import re
+import sys
 
 RE_EXPECTS = re.compile(r"^([0-9]+,)*[0-9]+$")
 
@@ -24,6 +25,7 @@ def parseInstruction(i):
             "OpExtension",
             "OpSource",
             "OpTypeInt",
+            "OpTypeStruct",
             "OpTypeFloat",
             "OpTypeBool",
             "OpTypeVoid",
@@ -52,12 +54,10 @@ def parseInstruction(i):
 # - if 2 subsequent delimiters will mean 2 pieces. One with only the first delimiter, and the second
 #   with the delimiter and following instructions.
 # - if the first instruction is a delimiter, the first piece will begin with this delimiter.
-def splitInstructions(
-    splitType: type, instructions: list[Instruction]
-) -> list[list[Instruction]]:
-    blocks = [[]]
+def splitInstructions(splitType: type, instructions: Iterable[Instruction]) -> list[list[Instruction]]:
+    blocks : list[list[Instruction]] = [[]]
     for instruction in instructions:
-        if type(instruction) is splitType and len(blocks[-1]) > 0:
+        if isinstance(instruction, splitType) and len(blocks[-1]) > 0:
             blocks.append([])
         blocks[-1].append(instruction)
     return blocks
@@ -66,8 +66,8 @@ def splitInstructions(
 # Defines a BasicBlock in the simulator.
 # Begins at an OpLabel, and ends with a control-flow instruction.
 class BasicBlock:
-    def __init__(self, instructions):
-        assert type(instructions[0]) is OpLabel
+    def __init__(self, instructions) -> None:
+        assert isinstance(instructions[0], OpLabel)
         # The name of the basic block, which is the register of the leading
         # OpLabel.
         self._name = instructions[0].output_register()
@@ -95,20 +95,19 @@ def dump(self):
 
 # Defines a Function in the simulator.
 class Function:
-    def __init__(self, instructions):
-        assert type(instructions[0]) is OpFunction
+    def __init__(self, instructions) -> None:
+        assert isinstance(instructions[0], OpFunction)
         # The name of the function (name of the register returned by OpFunction).
         self._name: str = instructions[0].output_register()
         # The list of basic blocks that belongs to this function.
         self._basic_blocks: list[BasicBlock] = []
         # The variables local to this function.
         self._variables: list[OpVariable] = [
-            x for x in instructions if type(x) is OpVariable
+            x for x in instructions if isinstance(x, OpVariable)
         ]
 
-        assert type(instructions[0]) is OpFunction
-        assert type(instructions[-1]) is OpFunctionEnd
-        body = filter(lambda x: type(x) != OpVariable, instructions[1:-1])
+        assert isinstance(instructions[-1], OpFunctionEnd)
+        body = filter(lambda x: not isinstance(x, OpVariable), instructions[1:-1])
         for block in splitInstructions(OpLabel, body):
             self._basic_blocks.append(BasicBlock(block))
 
@@ -169,26 +168,34 @@ def __add__(self, value: int):
         bb = self.function[self.basic_block]
         assert len(bb) > self.instruction_index + value
         return InstructionPointer(
-            self.function, self.basic_block, self.instruction_index + 1
+            self.function, self.basic_block, self.instruction_index + value
         )
 
-
 # Defines a Lane in this simulator.
 class Lane:
-    def __init__(self, wave, tid):
-        # The registers known by this lane.
-        self._registers: dict[str, Any] = {}
-        # The current IP of this lane.
-        self._ip: InstructionPointer = None
-        # If this lane running.
-        self._running: bool = True
-        # The wave this lane belongs to.
-        self._wave: Wave = wave
-        # The callstack of this lane. Each tuple represents 1 call.
-        #   The first element is the IP the function will return to.
-        #   The second element is the callback to call to store the return value
-        #   into the correct register.
-        self._callstack: list[tuple(InstructionPointer, Callback)] = []
+    # The registers known by this lane.
+    _registers: dict[str, Any]
+    # The current IP of this lane.
+    _ip: Optional[InstructionPointer]
+    # If this lane running.
+    _running: bool
+    # The wave this lane belongs to.
+    _wave: Wave
+    # The callstack of this lane. Each tuple represents 1 call.
+    #   The first element is the IP the function will return to.
+    #   The second element is the callback to call to store the return value
+    #   into the correct register.
+    _callstack: list[Tuple[InstructionPointer, Callable[[Any], None] ]]
+
+    _previous_bb : Optional[BasicBlock]
+    _current_bb : Optional[BasicBlock]
+
+    def __init__(self, wave : Wave, tid : int) -> None:
+        self._registers = dict()
+        self._ip = None
+        self._running = True
+        self._wave = wave
+        self._callstack = []
 
         # The index of this lane in the wave.
         self._tid = tid
@@ -198,38 +205,39 @@ def __init__(self, wave, tid):
         self._current_bb = None
 
     # Returns the lane/thread ID of this lane in its wave.
-    def tid(self):
+    def tid(self) -> int:
         return self._tid
 
     # Returns true is this lane if the first by index in the current active tangle.
-    def is_first_active_lane(self):
+    def is_first_active_lane(self) -> bool:
         return self._tid == self._wave.get_first_active_lane_index()
 
     # Broadcast value into the registers of all active lanes.
-    def broadcast_register(self, register, value):
+    def broadcast_register(self, register : str, value : Any) -> None:
         self._wave.broadcast_register(register, value)
 
     # Returns the IP this lane is currently at.
-    def ip(self):
+    def ip(self) -> InstructionPointer:
+        assert self._ip is not None
         return self._ip
 
     # Returns true if this lane is running, false otherwise.
     # Running means not dead. An inactive lane is running.
-    def running(self):
+    def running(self) -> bool:
         return self._running
 
     # Set the register at "name" to "value" in this lane.
-    def set_register(self, name, value):
+    def set_register(self, name : str, value : Any) -> None:
         self._registers[name] = value
 
     # Get the value in register "name" in this lane.
     # if allow_undef is true, fetching an unknown register won't fail.
-    def get_register(self, name, allow_undef=False):
+    def get_register(self, name : str, allow_undef : bool = False) -> Optional[Any]:
         if allow_undef and name not in self._registers:
             return None
         return self._registers[name]
 
-    def set_ip(self, ip):
+    def set_ip(self, ip : InstructionPointer) -> None:
         if ip.bb() != self._current_bb:
             self._previous_bb = self._current_bb
             self._current_bb = ip.bb()
@@ -261,7 +269,13 @@ def do_return(self, value):
 
 # Represents the SPIR-V module in the simulator.
 class Module:
-    def __init__(self, instructions):
+    _functions : dict[str, Function]
+    _prolog : list[Instruction]
+    _globals : list[Instruction]
+    _name2reg : dict[str, str]
+    _reg2name : dict[str, str]
+
+    def __init__(self, instructions) -> None:
         chunks = splitInstructions(OpFunction, instructions)
 
         # The instructions located outside of all functions.
@@ -272,14 +286,14 @@ def __init__(self, instructions):
         self._globals = [
             x
             for x in instructions
-            if type(x) is OpVariable or issubclass(type(x), OpConstant)
+            if isinstance(x, (OpVariable, OpConstant))
         ]
 
         # Helper dictionaries to get real names of registers, or registers by names.
         self._name2reg = {}
         self._reg2name = {}
         for instruction in instructions:
-            if type(instruction) is OpName:
+            if isinstance(instruction, OpName):
                 name = instruction.name()
                 reg = instruction.decoratedRegister()
                 self._name2reg[name] = reg
@@ -311,7 +325,7 @@ def initialize(self, lane):
 
         # Initialize builtins
         for instruction in self._prolog:
-            if type(instruction) is OpDecorate:
+            if isinstance(instruction, OpDecorate):
                 instruction.static_execution(lane)
 
     def execute_one_instruction(self, lane: Lane, ip: InstructionPointer) -> None:
@@ -340,10 +354,10 @@ def get_function_names(self):
         return [self.getNameFromRegister(reg) for reg, func in self._functions.items()]
 
     # Returns the global variables defined in this module.
-    def variables(self) -> iter:
+    def variables(self) -> Iterable:
         return [x.output_register() for x in self._globals]
 
-    def dump(self, function_name: str = None):
+    def dump(self, function_name: Optional[str] = None):
         print("Module:")
         print("  globals:")
         for instruction in self._globals:
@@ -371,35 +385,44 @@ def dump(self, function_name: str = None):
 @dataclass
 class ConvergenceRequirement:
     mergeTarget: InstructionPointer
-    continueTarget: InstructionPointer
+    continueTarget: Optional[InstructionPointer]
     impactedLanes: set[int]
 
+Task = dict[InstructionPointer, list[Lane]]
 
 # Defines a Lane group/Wave in the simulator.
 class Wave:
-    def __init__(self, module, wave_size: int):
+    # The module this wave will execute.
+    _module : Module
+    # The lanes this wave will be composed of.
+    _lanes : list[Lane]
+    # The instructions scheduled for execution.
+    _tasks : Task
+    # The actual requirements to comply with when executing instructions.
+    # e.g: the set of lanes required to merge before executing the merge block.
+    _convergence_requirements : list[ConvergenceRequirement]
+    # The indices of the active lanes for the current executing instruction.
+    _active_lane_indices : set[int]
+
+    def __init__(self, module, wave_size: int) -> None:
         assert wave_size > 0
-        # The module this wave will execute.
         self._module = module
-        # The lanes this wave will be composed of.
         self._lanes = []
+
         for i in range(wave_size):
             self._lanes.append(Lane(self, i))
 
-        # The instructions scheduled for execution.
-        self._tasks: dict(InstructionPointer, list[Lane]) = {}
-        # The actual requirements to comply with when executing instructions.
-        # e.g: the set of lanes required to merge before executing the merge block.
+        self._tasks = {}
         self._convergence_requirements = []
         # The indices of the active lanes for the current executing instruction.
         self._active_lane_indices = set()
 
     # Returns True if the given IP can be executed for the given list of lanes.
     def _is_task_candidate(self, ip: InstructionPointer, lanes: list[Lane]):
-        merged_lanes = set()
+        merged_lanes : set[int] = set()
         for lane in self._lanes:
             if not lane.running():
-                merged_lanes.add(lane)
+                merged_lanes.add(lane.tid())
 
         for requirement in self._convergence_requirements:
             # This task is not executing a merge or continue target.
@@ -435,7 +458,7 @@ def _is_task_candidate(self, ip: InstructionPointer, lanes: list[Lane]):
 
     # Returns the next task we can schedule. This must always return a task.
     # Calling this when all lanes are dead is invalid.
-    def _get_next_runnable_task(self):
+    def _get_next_runnable_task(self) -> Tuple[InstructionPointer, list[Lane]]:
         candidate = None
         for ip, lanes in self._tasks.items():
             if len(lanes) == 0:
@@ -451,7 +474,7 @@ def _get_next_runnable_task(self):
         raise RuntimeError("No task to execute. Deadlock?")
 
     # Handle an encountered merge instruction for the given lane.
-    def handle_convergence_header(self, lane: Lane, instruction: Instruction):
+    def handle_convergence_header(self, lane: Lane, instruction: MergeInstruction):
         mergeTarget = self._module.get_bb_entry(instruction.merge_location())
         for requirement in self._convergence_requirements:
             if requirement.mergeTarget == mergeTarget:
@@ -467,7 +490,7 @@ def handle_convergence_header(self, lane: Lane, instruction: Instruction):
         self._convergence_requirements.append(requirement)
 
     # Returns true if some instructions are scheduled for execution.
-    def _has_tasks(self):
+    def _has_tasks(self) -> bool:
         return len(self._tasks) > 0
 
     # Returns the index of the first active lane right now.
@@ -475,13 +498,13 @@ def get_first_active_lane_index(self) -> int:
         return min(self._active_lane_indices)
 
     # Broadcast the given value to all active lane registers'.
-    def broadcast_register(self, register, value) -> int:
+    def broadcast_register(self, register : str, value : Any) -> None:
         for tid in self._active_lane_indices:
             self._lanes[tid].set_register(register, value)
 
-    # Returns the function associated with 'name'.
+    # Returns the entrypoint of the function associated with 'name'.
     # Calling this function with an invalid name is illegal.
-    def _get_function_from_name(self, name: str) -> Function:
+    def _get_function_entry_from_name(self, name: str) -> InstructionPointer:
         register = self._module.getRegisterFromName(name)
         assert register is not None
         return self._module.get_function_entry(register)
@@ -489,14 +512,14 @@ def _get_function_from_name(self, name: str) -> Function:
     # Run the wave on the function 'function_name' until all lanes are dead.
     # If verbose is True, execution trace is printed.
     # Returns the value returned by the function for each lane.
-    def run(self, function_name: str, verbose: bool = False) -> list[int]:
+    def run(self, function_name : str, verbose: bool = False) -> list[Any]:
         for t in self._lanes:
             self._module.initialize(t)
 
-        function = self._get_function_from_name(function_name)
-        assert function is not None
+        entry_ip = self._get_function_entry_from_name(function_name)
+        assert entry_ip is not None
         for t in self._lanes:
-            t.do_call(function, "__shader_output__")
+            t.do_call(entry_ip, "__shader_output__")
 
         self._tasks[self._lanes[0].ip()] = self._lanes
         while self._has_tasks():
@@ -528,7 +551,7 @@ def run(self, function_name: str, verbose: bool = False) -> list[int]:
             output.append(lane.get_register("__shader_output__"))
         return output
 
-    def dump_register(self, register):
+    def dump_register(self, register : str) -> None:
         for lane in self._lanes:
             print(
                 f" Lane {lane.tid():2} | {register:3} = {lane.get_register(register)}"
@@ -552,11 +575,11 @@ def dump_register(self, register):
 args = parser.parse_args()
 
 
-def load_instructions(filename):
+def load_instructions(filename : str):
     if filename is None:
         return []
 
-    if filename.lstrip().rstrip() != "-":
+    if filename.strip() != "-":
         try:
             with open(filename, "r") as f:
                 lines = f.read().split("\n")
@@ -566,7 +589,7 @@ def load_instructions(filename):
         lines = sys.stdin.readlines()
 
     # Remove leading/trailing whitespaces.
-    lines = [x.rstrip().lstrip() for x in lines]
+    lines = [x.strip() for x in lines]
     # Strip comments.
     lines = [x for x in filter(lambda x: len(x) != 0 and x[0] != ";", lines)]
 
@@ -591,7 +614,7 @@ def main():
         print("Invalid format for --wave/-w flag.", file=sys.stderr)
         sys.exit(1)
 
-    expected_results = [int(x.rstrip().lstrip()) for x in args.expects.split(",")]
+    expected_results = [int(x.strip()) for x in args.expects.split(",")]
     wave_size = int(args.wave)
     if len(expected_results) != wave_size:
         print("Wave size != expected result array size", file=sys.stderr)
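`splitInstructions` now tests delimiters with `isinstance()` instead of `type(x) is T`. The two only disagree when an instruction is an instance of a subclass of the delimiter type. The sketch below re-implements the same loop on toy classes to show the splitting rules from the comment and the subclass case. `Instruction`, `OpLabel` and `OpLabelSubclass` here are hypothetical stand-ins, not the simulator's classes.

```python
from typing import Iterable


class Instruction:
    def __init__(self, name: str) -> None:
        self.name = name

    def __repr__(self) -> str:
        return self.name


class OpLabel(Instruction):
    pass


class OpLabelSubclass(OpLabel):
    """Hypothetical subclass, only here to show the isinstance() behavior."""


def split_instructions(
    split_type: type, instructions: Iterable[Instruction]
) -> list[list[Instruction]]:
    # Same loop as splitInstructions: start a new piece at every
    # delimiter, unless the current piece is still empty.
    blocks: list[list[Instruction]] = [[]]
    for instruction in instructions:
        if isinstance(instruction, split_type) and len(blocks[-1]) > 0:
            blocks.append([])
        blocks[-1].append(instruction)
    return blocks


body = [
    OpLabel("%1"),
    Instruction("OpBranch"),
    OpLabelSubclass("%2"),
    Instruction("OpReturn"),
]
# With isinstance(), the subclass instance also starts a new piece.
print(split_instructions(OpLabel, body))  # [[%1, OpBranch], [%2, OpReturn]]
```

With the old `type(x) is T` check, `%2` would not have started a new piece and the whole body would have stayed in one block.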

From 6c93b7e67faa4578df7a3beab798aa068ad76359 Mon Sep 17 00:00:00 2001
From: Nathan Gauër <brioche at google.com>
Date: Thu, 22 Aug 2024 17:32:14 +0200
Subject: [PATCH 5/5] format
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Nathan Gauër <brioche at google.com>
---
 llvm/utils/spirv-sim/spirv-sim.py | 57 +++++++++++++++++--------------
 1 file changed, 31 insertions(+), 26 deletions(-)

diff --git a/llvm/utils/spirv-sim/spirv-sim.py b/llvm/utils/spirv-sim/spirv-sim.py
index dd2af79b258d91..74b9cb3dc79785 100755
--- a/llvm/utils/spirv-sim/spirv-sim.py
+++ b/llvm/utils/spirv-sim/spirv-sim.py
@@ -3,7 +3,7 @@
 from __future__ import annotations
 from dataclasses import dataclass
 from instructions import *
-from typing import Any,Iterable,Callable,Optional,Tuple
+from typing import Any, Iterable, Callable, Optional, Tuple
 import argparse
 import fileinput
 import inspect
@@ -54,8 +54,10 @@ def parseInstruction(i):
 # - if 2 subsequent delimiters will mean 2 pieces. One with only the first delimiter, and the second
 #   with the delimiter and following instructions.
 # - if the first instruction is a delimiter, the first piece will begin with this delimiter.
-def splitInstructions(splitType: type, instructions: Iterable[Instruction]) -> list[list[Instruction]]:
-    blocks : list[list[Instruction]] = [[]]
+def splitInstructions(
+    splitType: type, instructions: Iterable[Instruction]
+) -> list[list[Instruction]]:
+    blocks: list[list[Instruction]] = [[]]
     for instruction in instructions:
         if isinstance(instruction, splitType) and len(blocks[-1]) > 0:
             blocks.append([])
@@ -171,6 +173,7 @@ def __add__(self, value: int):
             self.function, self.basic_block, self.instruction_index + value
         )
 
+
 # Defines a Lane in this simulator.
 class Lane:
     # The registers known by this lane.
@@ -185,12 +188,12 @@ class Lane:
     #   The first element is the IP the function will return to.
     #   The second element is the callback to call to store the return value
     #   into the correct register.
-    _callstack: list[Tuple[InstructionPointer, Callable[[Any], None] ]]
+    _callstack: list[Tuple[InstructionPointer, Callable[[Any], None]]]
 
-    _previous_bb : Optional[BasicBlock]
-    _current_bb : Optional[BasicBlock]
+    _previous_bb: Optional[BasicBlock]
+    _current_bb: Optional[BasicBlock]
 
-    def __init__(self, wave : Wave, tid : int) -> None:
+    def __init__(self, wave: Wave, tid: int) -> None:
         self._registers = dict()
         self._ip = None
         self._running = True
@@ -213,7 +216,7 @@ def is_first_active_lane(self) -> bool:
         return self._tid == self._wave.get_first_active_lane_index()
 
     # Broadcast value into the registers of all active lanes.
-    def broadcast_register(self, register : str, value : Any) -> None:
+    def broadcast_register(self, register: str, value: Any) -> None:
         self._wave.broadcast_register(register, value)
 
     # Returns the IP this lane is currently at.
@@ -227,17 +230,17 @@ def running(self) -> bool:
         return self._running
 
     # Set the register at "name" to "value" in this lane.
-    def set_register(self, name : str, value : Any) -> None:
+    def set_register(self, name: str, value: Any) -> None:
         self._registers[name] = value
 
     # Get the value in register "name" in this lane.
     # if allow_undef is true, fetching an unknown register won't fail.
-    def get_register(self, name : str, allow_undef : bool = False) -> Optional[Any]:
+    def get_register(self, name: str, allow_undef: bool = False) -> Optional[Any]:
         if allow_undef and name not in self._registers:
             return None
         return self._registers[name]
 
-    def set_ip(self, ip : InstructionPointer) -> None:
+    def set_ip(self, ip: InstructionPointer) -> None:
         if ip.bb() != self._current_bb:
             self._previous_bb = self._current_bb
             self._current_bb = ip.bb()
@@ -269,11 +272,11 @@ def do_return(self, value):
 
 # Represents the SPIR-V module in the simulator.
 class Module:
-    _functions : dict[str, Function]
-    _prolog : list[Instruction]
-    _globals : list[Instruction]
-    _name2reg : dict[str, str]
-    _reg2name : dict[str, str]
+    _functions: dict[str, Function]
+    _prolog: list[Instruction]
+    _globals: list[Instruction]
+    _name2reg: dict[str, str]
+    _reg2name: dict[str, str]
 
     def __init__(self, instructions) -> None:
         chunks = splitInstructions(OpFunction, instructions)
@@ -388,21 +391,23 @@ class ConvergenceRequirement:
     continueTarget: Optional[InstructionPointer]
     impactedLanes: set[int]
 
+
 Task = dict[InstructionPointer, list[Lane]]
 
+
 # Defines a Lane group/Wave in the simulator.
 class Wave:
     # The module this wave will execute.
-    _module : Module
+    _module: Module
     # The lanes this wave will be composed of.
-    _lanes : list[Lane]
+    _lanes: list[Lane]
     # The instructions scheduled for execution.
-    _tasks : Task
+    _tasks: Task
     # The actual requirements to comply with when executing instructions.
     # e.g: the set of lanes required to merge before executing the merge block.
-    _convergence_requirements : list[ConvergenceRequirement]
+    _convergence_requirements: list[ConvergenceRequirement]
     # The indices of the active lanes for the current executing instruction.
-    _active_lane_indices : set[int]
+    _active_lane_indices: set[int]
 
     def __init__(self, module, wave_size: int) -> None:
         assert wave_size > 0
@@ -419,7 +424,7 @@ def __init__(self, module, wave_size: int) -> None:
 
     # Returns True if the given IP can be executed for the given list of lanes.
     def _is_task_candidate(self, ip: InstructionPointer, lanes: list[Lane]):
-        merged_lanes : set[int] = set()
+        merged_lanes: set[int] = set()
         for lane in self._lanes:
             if not lane.running():
                 merged_lanes.add(lane.tid())
@@ -498,7 +503,7 @@ def get_first_active_lane_index(self) -> int:
         return min(self._active_lane_indices)
 
     # Broadcast the given value to all active lane registers'.
-    def broadcast_register(self, register : str, value : Any) -> None:
+    def broadcast_register(self, register: str, value: Any) -> None:
         for tid in self._active_lane_indices:
             self._lanes[tid].set_register(register, value)
 
@@ -512,7 +517,7 @@ def _get_function_entry_from_name(self, name: str) -> InstructionPointer:
     # Run the wave on the function 'function_name' until all lanes are dead.
     # If verbose is True, execution trace is printed.
     # Returns the value returned by the function for each lane.
-    def run(self, function_name : str, verbose: bool = False) -> list[Any]:
+    def run(self, function_name: str, verbose: bool = False) -> list[Any]:
         for t in self._lanes:
             self._module.initialize(t)
 
@@ -551,7 +556,7 @@ def run(self, function_name : str, verbose: bool = False) -> list[Any]:
             output.append(lane.get_register("__shader_output__"))
         return output
 
-    def dump_register(self, register : str) -> None:
+    def dump_register(self, register: str) -> None:
         for lane in self._lanes:
             print(
                 f" Lane {lane.tid():2} | {register:3} = {lane.get_register(register)}"
@@ -575,7 +580,7 @@ def dump_register(self, register : str) -> None:
 args = parser.parse_args()
 
 
-def load_instructions(filename : str):
+def load_instructions(filename: str):
     if filename is None:
         return []
 

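The `merged_lanes` change in `_is_task_candidate` (patch 4) fixes a type mismatch: dead lanes were added as `Lane` objects, while `ConvergenceRequirement.impactedLanes` is a `set[int]` of lane indices. The rest of `_is_task_candidate` is not visible here, so the subset test below is an assumed simplification of how the two sets are compared. `ToyLane` is a hypothetical stand-in for `Lane` and is assumed not to compare equal to integers.

```python
class ToyLane:
    """Hypothetical stand-in for Lane: only a tid and a running flag."""

    def __init__(self, tid: int, running: bool) -> None:
        self._tid = tid
        self._running = running

    def tid(self) -> int:
        return self._tid

    def running(self) -> bool:
        return self._running


def merged_lanes_old(lanes: list[ToyLane]) -> set:
    # Before the fix: dead lanes were collected as lane objects.
    return {lane for lane in lanes if not lane.running()}


def merged_lanes_new(lanes: list[ToyLane]) -> set[int]:
    # After the fix: dead lanes are collected by tid, like impactedLanes.
    return {lane.tid() for lane in lanes if not lane.running()}


lanes = [ToyLane(0, True), ToyLane(1, False), ToyLane(2, False)]
impacted = {1, 2}  # plays the role of ConvergenceRequirement.impactedLanes
print(impacted <= merged_lanes_old(lanes))  # False: ints never match objects
print(impacted <= merged_lanes_new(lanes))  # True
```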

