[llvm] [Utils][SPIR-V] Adding spirv-sim to LLVM (PR #104020)

Ilia Diachkov via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 26 05:07:05 PDT 2024


Nathan Gauër <brioche at google.com>,
Nathan Gauër <brioche at google.com>,
Nathan Gauër <brioche at google.com>,
Nathan Gauër <brioche at google.com>,
Nathan Gauër <brioche at google.com>
Message-ID:
In-Reply-To: <llvm.org/llvm/llvm-project/pull/104020 at github.com>


================
@@ -0,0 +1,658 @@
+#!/usr/bin/env python3
+
+from __future__ import annotations
+from dataclasses import dataclass
+from instructions import *
+from typing import Any, Iterable, Callable, Optional, Tuple
+import argparse
+import fileinput
+import inspect
+import re
+import sys
+
+RE_EXPECTS = re.compile(r"^([0-9]+,)*[0-9]+$")
+
+
+# Parse the SPIR-V instructions. Some instructions are ignored because
+# not required to simulate this module.
+# Instructions are to be implemented in instructions.py
+def parseInstruction(i):
+    IGNORED = set(
+        [
+            "OpCapability",
+            "OpMemoryModel",
+            "OpExecutionMode",
+            "OpExtension",
+            "OpSource",
+            "OpTypeInt",
+            "OpTypeStruct",
+            "OpTypeFloat",
+            "OpTypeBool",
+            "OpTypeVoid",
+            "OpTypeFunction",
+            "OpTypePointer",
+            "OpTypeArray",
+        ]
+    )
+    if i.opcode() in IGNORED:
+        return None
+
+    try:
+        Type = getattr(sys.modules["instructions"], i.opcode())
+    except AttributeError:
+        raise RuntimeError(f"Unsupported instruction {i}")
+    if not inspect.isclass(Type):
+        raise RuntimeError(
+            f"{i} instruction definition is not a class. Did you used 'def' instead of 'class'?"
+        )
+    return Type(i.line)
+
+
+# Split a list of instructions into pieces. Pieces are delimited by instructions of the type splitType.
+# The delimiter is the first instruction of the next piece.
+# This function returns no empty pieces:
+# - if 2 subsequent delimiters will mean 2 pieces. One with only the first delimiter, and the second
+#   with the delimiter and following instructions.
+# - if the first instruction is a delimiter, the first piece will begin with this delimiter.
+def splitInstructions(
+    splitType: type, instructions: Iterable[Instruction]
+) -> list[list[Instruction]]:
+    blocks: list[list[Instruction]] = [[]]
+    for instruction in instructions:
+        if isinstance(instruction, splitType) and len(blocks[-1]) > 0:
+            blocks.append([])
+        blocks[-1].append(instruction)
+    return blocks
+
+
+# Defines a BasicBlock in the simulator.
+# Begins at an OpLabel, and ends with a control-flow instruction.
+class BasicBlock:
+    def __init__(self, instructions) -> None:
+        assert isinstance(instructions[0], OpLabel)
+        # The name of the basic block, which is the register of the leading
+        # OpLabel.
+        self._name = instructions[0].output_register()
+        # The list of instructions belonging to this block.
+        self._instructions = instructions[1:]
+
+    # Returns the name of this basic block.
+    def name(self):
+        return self._name
+
+    # Returns the instruction at index in this basic block.
+    def __getitem__(self, index: int) -> Instruction:
+        return self._instructions[index]
+
+    # Returns the number of instructions in this basic block, excluding the
+    # leading OpLabel.
+    def __len__(self):
+        return len(self._instructions)
+
+    def dump(self):
+        print(f"        {self._name}:")
+        for instruction in self._instructions:
+            print(f"        {instruction}")
+
+
+# Defines a Function in the simulator.
+class Function:
+    def __init__(self, instructions) -> None:
+        assert isinstance(instructions[0], OpFunction)
+        # The name of the function (name of the register returned by OpFunction).
+        self._name: str = instructions[0].output_register()
+        # The list of basic blocks that belongs to this function.
+        self._basic_blocks: list[BasicBlock] = []
+        # The variables local to this function.
+        self._variables: list[OpVariable] = [
+            x for x in instructions if isinstance(x, OpVariable)
+        ]
+
+        assert isinstance(instructions[-1], OpFunctionEnd)
+        body = filter(lambda x: not isinstance(x, OpVariable), instructions[1:-1])
+        for block in splitInstructions(OpLabel, body):
+            self._basic_blocks.append(BasicBlock(block))
+
+    # Returns the name of this function.
+    def name(self) -> str:
+        return self._name
+
+    # Returns the basic block at index in this function.
+    def __getitem__(self, index: int) -> BasicBlock:
+        return self._basic_blocks[index]
+
+    # Returns the index of the basic block with the given name if found,
+    # -1 otherwise.
+    def get_bb_index(self, name) -> int:
+        for i in range(len(self._basic_blocks)):
+            if self._basic_blocks[i].name() == name:
+                return i
+        return -1
+
+    def dump(self):
+        print("      Variables:")
+        for var in self._variables:
+            print(f"        {var}")
+        print("      Blocks:")
+        for bb in self._basic_blocks:
+            bb.dump()
+
+
+# Represents an instruction pointer in the simulator.
+ at dataclass
+class InstructionPointer:
+    # The current function the IP points to.
+    function: Function
+    # The basic block index in function IP points to.
+    basic_block: int
+    # The instruction in basic_block IP points to.
+    instruction_index: int
+
+    def __str__(self):
+        bb = self.function[self.basic_block]
+        i = bb[self.instruction_index]
+        return f"{bb.name()}:{self.instruction_index} in {self.function.name()} | {i}"
+
+    def __hash__(self):
+        return hash((self.function.name(), self.basic_block, self.instruction_index))
+
+    # Returns the basic block IP points to.
+    def bb(self) -> BasicBlock:
+        return self.function[self.basic_block]
+
+    # Returns the instruction IP points to.
+    def instruction(self):
+        return self.function[self.basic_block][self.instruction_index]
+
+    # Increment IP by 1. This only works inside a basic-block boundary.
+    # Incrementing IP when at the boundary of a basic block will fail.
+    def __add__(self, value: int):
+        bb = self.function[self.basic_block]
+        assert len(bb) > self.instruction_index + value
+        return InstructionPointer(
+            self.function, self.basic_block, self.instruction_index + value
+        )
+
+
+# Defines a Lane in this simulator.
+class Lane:
+    # The registers known by this lane.
+    _registers: dict[str, Any]
+    # The current IP of this lane.
+    _ip: Optional[InstructionPointer]
+    # If this lane running.
+    _running: bool
+    # The wave this lane belongs to.
+    _wave: Wave
+    # The callstack of this lane. Each tuple represents 1 call.
+    #   The first element is the IP the function will return to.
+    #   The second element is the callback to call to store the return value
+    #   into the correct register.
+    _callstack: list[Tuple[InstructionPointer, Callable[[Any], None]]]
+
+    _previous_bb: Optional[BasicBlock]
+    _current_bb: Optional[BasicBlock]
+
+    def __init__(self, wave: Wave, tid: int) -> None:
+        self._registers = dict()
+        self._ip = None
+        self._running = True
+        self._wave = wave
+        self._callstack = []
+
+        # The index of this lane in the wave.
+        self._tid = tid
+        # The last BB this lane was executing into.
+        self._previous_bb = None
+        # The current BB this lane is executing into.
+        self._current_bb = None
+
+    # Returns the lane/thread ID of this lane in its wave.
+    def tid(self) -> int:
+        return self._tid
+
+    # Returns true is this lane if the first by index in the current active tangle.
+    def is_first_active_lane(self) -> bool:
+        return self._tid == self._wave.get_first_active_lane_index()
+
+    # Broadcast value into the registers of all active lanes.
+    def broadcast_register(self, register: str, value: Any) -> None:
+        self._wave.broadcast_register(register, value)
+
+    # Returns the IP this lane is currently at.
+    def ip(self) -> InstructionPointer:
+        assert self._ip is not None
+        return self._ip
+
+    # Returns true if this lane is running, false otherwise.
+    # Running means not dead. An inactive lane is running.
+    def running(self) -> bool:
+        return self._running
+
+    # Set the register at "name" to "value" in this lane.
+    def set_register(self, name: str, value: Any) -> None:
+        self._registers[name] = value
+
+    # Get the value in register "name" in this lane.
+    # if allow_undef is true, fetching an unknown register won't fail.
+    def get_register(self, name: str, allow_undef: bool = False) -> Optional[Any]:
+        if allow_undef and name not in self._registers:
+            return None
+        return self._registers[name]
+
+    def set_ip(self, ip: InstructionPointer) -> None:
+        if ip.bb() != self._current_bb:
+            self._previous_bb = self._current_bb
+            self._current_bb = ip.bb()
+        self._ip = ip
+
+    def get_previous_bb_name(self):
+        return self._previous_bb.name()
+
+    def handle_convergence_header(self, instruction):
+        self._wave.handle_convergence_header(self, instruction)
+
+    def do_call(self, ip, output_register):
+        return_ip = None if self._ip is None else self._ip + 1
+        self._callstack.append(
+            (return_ip, lambda value: self.set_register(output_register, value))
+        )
+        self.set_ip(ip)
+
+    def do_return(self, value):
+        ip, callback = self._callstack[-1]
+        self._callstack.pop()
+
+        callback(value)
+        if len(self._callstack) == 0:
+            self._running = False
+        else:
+            self.set_ip(ip)
+
+
+# Represents the SPIR-V module in the simulator.
+class Module:
+    _functions: dict[str, Function]
+    _prolog: list[Instruction]
+    _globals: list[Instruction]
+    _name2reg: dict[str, str]
+    _reg2name: dict[str, str]
+
+    def __init__(self, instructions) -> None:
+        chunks = splitInstructions(OpFunction, instructions)
+
+        # The instructions located outside of all functions.
+        self._prolog = chunks[0]
+        # The functions in this module.
+        self._functions = {}
+        # Global variables in this module.
+        self._globals = [
+            x
+            for x in instructions
+            if isinstance(x, OpVariable) or issubclass(type(x), OpConstant)
+        ]
+
+        # Helper dictionaries to get real names of registers, or registers by names.
+        self._name2reg = {}
+        self._reg2name = {}
+        for instruction in instructions:
+            if isinstance(instruction, OpName):
+                name = instruction.name()
+                reg = instruction.decoratedRegister()
+                self._name2reg[name] = reg
+                self._reg2name[reg] = name
+
+        for chunk in chunks[1:]:
+            function = Function(chunk)
+            assert function.name() not in self._functions
+            self._functions[function.name()] = function
+
+    # Returns the register matching "name" if any, None otherwise.
+    # This assumes names are unique.
+    def getRegisterFromName(self, name):
+        if name in self._name2reg:
+            return self._name2reg[name]
+        return None
+
+    # Returns the name given to "register" if any, None otherwise.
+    def getNameFromRegister(self, register):
+        if register in self._reg2name:
+            return self._reg2name[register]
+        return None
+
+    # Initialize the module before wave execution begins.
+    # See Instruction::static_execution for more details.
+    def initialize(self, lane):
+        for instruction in self._globals:
+            instruction.static_execution(lane)
+
+        # Initialize builtins
+        for instruction in self._prolog:
+            if isinstance(instruction, OpDecorate):
+                instruction.static_execution(lane)
+
+    def execute_one_instruction(self, lane: Lane, ip: InstructionPointer) -> None:
+        ip.instruction().runtime_execution(self, lane)
+
+    # Returns the first valid IP for the function defined by the given register.
+    # Calling this with a register not returned by OpFunction is illegal.
+    def get_function_entry(self, register: str) -> InstructionPointer:
+        if register not in self._functions:
+            raise RuntimeError(f"Function defining {register} not found.")
+        return InstructionPointer(self._functions[register], 0, 0)
+
+    # Returns the first valid IP for the basic block defined by register.
+    # Calling this with a register not returned by an OpLabel is illegal.
+    def get_bb_entry(self, register: str) -> InstructionPointer:
+        for name, function in self._functions.items():
+            index = function.get_bb_index(register)
+            if index != -1:
+                return InstructionPointer(function, index, 0)
+        raise RuntimeError(f"Instruction defining {register} not found.")
+
+    # Returns the list of function names in this module.
+    # If an OpName exists for this function, returns the pretty name, else
+    # returns the register name.
+    def get_function_names(self):
+        return [self.getNameFromRegister(reg) for reg, func in self._functions.items()]
+
+    # Returns the global variables defined in this module.
+    def variables(self) -> Iterable:
+        return [x.output_register() for x in self._globals]
+
+    def dump(self, function_name: Optional[str] = None):
+        print("Module:")
+        print("  globals:")
+        for instruction in self._globals:
+            print(f"    {instruction}")
+
+        if function_name is None:
+            print("  functions:")
+            for register, function in self._functions.items():
+                name = self.getNameFromRegister(register)
+                print(f"  Function {register} ({name})")
+                function.dump()
+            return
+
+        register = self.getRegisterFromName(function_name)
+        print(f"  function {register} ({function_name}):")
+        if register is not None:
+            self._functions[register].dump()
+        else:
+            print(f"    error: cannot find function.")
+
+
+# Defines a convergence requirement for the simulation:
+# A list of lanes impacted by a merge and possibly the associated
+# continue target.
+ at dataclass
+class ConvergenceRequirement:
+    mergeTarget: InstructionPointer
+    continueTarget: Optional[InstructionPointer]
+    impactedLanes: set[int]
+
+
+# Scheduled work: maps each IP to the lanes waiting to execute it.
+Task = dict[InstructionPointer, list[Lane]]
+
+
+# Defines a Lane group/Wave in the simulator.
+class Wave:
+    # The module this wave will execute.
+    _module: Module
+    # The lanes this wave will be composed of.
+    _lanes: list[Lane]
+    # The instructions scheduled for execution.
+    _tasks: Task
+    # The actual requirements to comply with when executing instructions.
+    # e.g: the set of lanes required to merge before executing the merge block.
+    _convergence_requirements: list[ConvergenceRequirement]
+    # The indices of the active lanes for the current executing instruction.
+    _active_lane_indices: set[int]
+
+    def __init__(self, module, wave_size: int) -> None:
+        assert wave_size > 0
+        self._module = module
+        self._lanes = []
+
+        for i in range(wave_size):
+            self._lanes.append(Lane(self, i))
+
+        self._tasks = {}
+        self._convergence_requirements = []
+        # The indices of the active lanes for the current executing instruction.
+        self._active_lane_indices = set()
+
+    # Returns True if the given IP can be executed for the given list of lanes.
+    # A task located at a requirement's merge (or continue) target is held
+    # back while a running lane bound by that requirement has not reached it
+    # yet; a task at any other IP is always allowed. Dead lanes never block.
+    # NOTE(review): the `lanes` argument is currently unused.
+    def _is_task_candidate(self, ip: InstructionPointer, lanes: list[Lane]) -> bool:
+        # Lanes that cannot block this task: dead lanes, plus (below) lanes
+        # already parked at the merge/continue target of another requirement.
+        merged_lanes: set[int] = set()
+        for lane in self._lanes:
+            if not lane.running():
+                merged_lanes.add(lane.tid())
+
+        for requirement in self._convergence_requirements:
+            # This task is not executing this requirement's merge or continue
+            # target: lanes waiting at those points are added to the ignore set.
+            if requirement.mergeTarget != ip and requirement.continueTarget != ip:
+                for tid in requirement.impactedLanes:
+                    if self._lanes[tid].ip() == requirement.mergeTarget:
+                        merged_lanes.add(tid)
+                    if self._lanes[tid].ip() == requirement.continueTarget:
+                        merged_lanes.add(tid)
+                continue
+
+            # This task is executing the current requirement continue/merge
+            # target.
+            for tid in requirement.impactedLanes:
+                lane = self._lanes[tid]
+                if not lane.running():
+                    continue
+
+                if lane.tid() in merged_lanes:
+                    continue
+
+                if ip == requirement.mergeTarget:
+                    # Executing the merge block: every remaining lane must
+                    # already be waiting at it.
+                    if lane.ip() != requirement.mergeTarget:
+                        return False
+                else:
+                    # Executing the continue block: every remaining lane must
+                    # be at either the merge or the continue target.
+                    if (
+                        lane.ip() != requirement.mergeTarget
+                        and lane.ip() != requirement.continueTarget
+                    ):
+                        return False
+        return True
+
+    # Returns the next task we can schedule. This must always return a task.
+    # Calling this when all lanes are dead is invalid.
+    def _get_next_runnable_task(self) -> Tuple[InstructionPointer, list[Lane]]:
+        candidate = None
+        for ip, lanes in self._tasks.items():
+            if len(lanes) == 0:
+                continue
+            if self._is_task_candidate(ip, lanes):
+                candidate = ip
+                break
+
+        if candidate:
+            lanes = self._tasks[candidate]
+            del self._tasks[ip]
+            return (candidate, lanes)
+        raise RuntimeError("No task to execute. Deadlock?")
+
+    # Handle an encountered merge instruction for the given lane.
+    def handle_convergence_header(self, lane: Lane, instruction: MergeInstruction):
+        mergeTarget = self._module.get_bb_entry(instruction.merge_location())
+        for requirement in self._convergence_requirements:
+            if requirement.mergeTarget == mergeTarget:
+                requirement.impactedLanes.add(lane.tid())
+                return
+
+        continueTarget = None
+        if instruction.continue_location():
+            continueTarget = self._module.get_bb_entry(instruction.continue_location())
+        requirement = ConvergenceRequirement(
+            mergeTarget, continueTarget, set([lane.tid()])
+        )
+        self._convergence_requirements.append(requirement)
+
+    # Returns true if some instructions are scheduled for execution.
+    def _has_tasks(self) -> bool:
+        return len(self._tasks) > 0
+
+    # Returns the index of the first active lane right now.
+    def get_first_active_lane_index(self) -> int:
+        return min(self._active_lane_indices)
+
+    # Broadcast the given value to the registers of all active lanes.
----------------
iliya-diyachkov wrote:

"registers'" -> "registers"?

https://github.com/llvm/llvm-project/pull/104020


More information about the llvm-commits mailing list