[llvm] [Utils][SPIR-V] Adding spirv-sim to LLVM (PR #104020)
Michal Paszkowski via llvm-commits
llvm-commits at lists.llvm.org
Mon Aug 19 23:02:00 PDT 2024
Nathan Gauër <brioche at google.com>,
Nathan Gauër <brioche at google.com>
Message-ID:
In-Reply-To: <llvm.org/llvm/llvm-project/pull/104020 at github.com>
================
@@ -0,0 +1,630 @@
+#!/usr/bin/env python3
+
+import fileinput
+import inspect
+from typing import Any
+from dataclasses import dataclass
+import sys
+from instructions import *
+import argparse
+import re
+
+RE_EXPECTS = re.compile(r"^([0-9]+,)*[0-9]+$")
+
+
# Parse one SPIR-V instruction into its simulator representation.
# Opcodes listed in IGNORED are not required to simulate a module: they are
# skipped and None is returned.
# Every other opcode must be implemented in instructions.py as a class named
# after the opcode; that class is instantiated with the instruction's line.
# Raises RuntimeError if instructions.py has no attribute named after the
# opcode, or if that attribute is not a class.
def parseInstruction(i):
    IGNORED = {
        "OpCapability",
        "OpMemoryModel",
        "OpExecutionMode",
        "OpExtension",
        "OpSource",
        "OpTypeInt",
        "OpTypeFloat",
        "OpTypeBool",
        "OpTypeVoid",
        "OpTypeFunction",
        "OpTypePointer",
        "OpTypeArray",
    }
    opcode = i.opcode()
    if opcode in IGNORED:
        return None

    try:
        Type = getattr(sys.modules["instructions"], opcode)
    except AttributeError as error:
        # Keep the original lookup failure attached to the reported error.
        raise RuntimeError(f"Unsupported instruction {i}") from error
    if not inspect.isclass(Type):
        raise RuntimeError(
            f"{i} instruction definition is not a class. Did you use 'def' instead of 'class'?"
        )
    return Type(i.line)
+
+
# Split a list of instructions into pieces, using instructions whose exact
# type is splitType as delimiters.
# A delimiter always starts a new piece and is kept as its first instruction:
#  - two consecutive delimiters produce two pieces: one holding only the
#    first delimiter, and one holding the second delimiter followed by the
#    instructions after it.
#  - if the very first instruction is a delimiter, no empty piece is emitted
#    before it: the first piece begins with this delimiter.
# An empty input yields a single empty piece.
def splitInstructions(
    splitType: type, instructions: list[Instruction]
) -> list[list[Instruction]]:
    current = []
    pieces = [current]
    for item in instructions:
        # Only open a new piece when the current one already has content.
        if current and type(item) is splitType:
            current = []
            pieces.append(current)
        current.append(item)
    return pieces
+
+
# A basic block of the simulated module.
# The instruction list it is built from must start with an OpLabel; the
# label gives the block its name and is not kept in the block's instructions.
class BasicBlock:
    def __init__(self, instructions):
        label = instructions[0]
        assert type(label) is OpLabel
        # The block is named after the register defined by its OpLabel.
        self._name = label.output_register()
        # Everything following the label belongs to the block.
        self._instructions = instructions[1:]

    # Returns the name of this basic block.
    def name(self):
        return self._name

    # Returns the instruction stored at position index, the leading OpLabel
    # not being counted.
    def __getitem__(self, index: int) -> Instruction:
        return self._instructions[index]

    # Returns how many instructions this basic block holds, the leading
    # OpLabel not being counted.
    def __len__(self):
        return len(self._instructions)

    # Prints the block name followed by one line per instruction.
    def dump(self):
        header = f" {self._name}:"
        print(header)
        for item in self._instructions:
            print(f" {item}")
+
+
# Defines a Function in the simulator.
# Built from the instructions going from an OpFunction to the matching
# OpFunctionEnd, both included.
class Function:
    def __init__(self, instructions):
        assert type(instructions[0]) is OpFunction
        assert type(instructions[-1]) is OpFunctionEnd
        # The name of the function (name of the register returned by OpFunction).
        self._name: str = instructions[0].output_register()
        # The variables local to this function.
        self._variables: list[OpVariable] = [
            x for x in instructions if type(x) is OpVariable
        ]

        # The list of basic blocks that belongs to this function. OpVariable
        # instructions are tracked separately and are not part of any block.
        body = [x for x in instructions[1:-1] if type(x) is not OpVariable]
        # splitInstructions returns one empty piece for an empty body: skip
        # it so a function without any instruction has no basic block.
        self._basic_blocks: list[BasicBlock] = [
            BasicBlock(block) for block in splitInstructions(OpLabel, body) if block
        ]

    # Returns the name of this function.
    def name(self) -> str:
        return self._name

    # Returns the basic block at index in this function.
    def __getitem__(self, index: int) -> BasicBlock:
        return self._basic_blocks[index]

    # Returns the index of the basic block with the given name if found,
    # -1 otherwise.
    def get_bb_index(self, name) -> int:
        for index, basic_block in enumerate(self._basic_blocks):
            if basic_block.name() == name:
                return index
        return -1

    # Prints the local variables, then every basic block of this function.
    def dump(self):
        print(" Variables:")
        for var in self._variables:
            print(f" {var}")
        print(" Blocks:")
        for bb in self._basic_blocks:
            bb.dump()
+
+
# Represents an instruction pointer in the simulator.
@dataclass
class InstructionPointer:
    # The current function the IP points to.
    function: Function
    # The basic block index in function IP points to.
    basic_block: int
    # The instruction in basic_block IP points to.
    instruction_index: int

    def __str__(self):
        bb = self.bb()
        i = bb[self.instruction_index]
        return f"{bb.name()}:{self.instruction_index} in {self.function.name()} | {i}"

    # Hash on the function name and both indices, so an IP can be used as a
    # dictionary key.
    def __hash__(self):
        return hash((self.function.name(), self.basic_block, self.instruction_index))

    # Returns the basic block IP points to.
    def bb(self) -> BasicBlock:
        return self.function[self.basic_block]

    # Returns the instruction IP points to.
    def instruction(self):
        return self.bb()[self.instruction_index]

    # Returns a new IP moved forward by value instructions. This only works
    # inside a basic-block boundary: moving the IP past the last instruction
    # of its basic block fails with an AssertionError.
    def __add__(self, value: int):
        target_index = self.instruction_index + value
        assert len(self.bb()) > target_index
        return InstructionPointer(self.function, self.basic_block, target_index)
+
+
# Defines a Lane in this simulator.
class Lane:
    def __init__(self, wave, tid):
        # The registers known by this lane.
        self._registers: dict[str, Any] = {}
        # The current IP of this lane.
        self._ip: InstructionPointer = None
        # Whether this lane is running.
        self._running: bool = True
        # The wave this lane belongs to.
        self._wave: Wave = wave
        # The callstack of this lane. Each tuple represents 1 call.
        # The first element is the IP the function will return to (None when
        # the call was made while the lane had no IP yet).
        # The second element is the callback to call to store the return value
        # into the correct register.
        self._callstack: list[tuple[InstructionPointer, Any]] = []

        # The index of this lane in the wave.
        self._tid = tid
        # The last BB this lane was executing into.
        self._previous_bb = None
        # The current BB this lane is executing into.
        self._current_bb = None

    # Returns the lane/thread ID of this lane in its wave.
    def tid(self):
        return self._tid

    # Returns true if this lane is the first, by index, of the lanes the wave
    # currently reports as active.
    def is_first_active_lane(self):
        return self._tid == self._wave.get_first_active_lane_index()

    # Broadcast value into the registers of all active lanes.
    def broadcast_register(self, register, value):
        self._wave.broadcast_register(register, value)

    # Returns the IP this lane is currently at.
    def ip(self):
        return self._ip

    # Returns true if this lane is running, false otherwise.
    # Running means not dead. An inactive lane is running.
    def running(self):
        return self._running

    # Set the register at "name" to "value" in this lane.
    def set_register(self, name, value):
        self._registers[name] = value

    # Get the value in register "name" in this lane.
    # If allow_undef is true, fetching an unknown register returns None;
    # otherwise it raises KeyError.
    def get_register(self, name, allow_undef=False):
        if allow_undef:
            return self._registers.get(name)
        return self._registers[name]

    # Move this lane to ip. When ip lies in a different basic block than the
    # current one, the current block is remembered as the previous block.
    def set_ip(self, ip):
        if ip.bb() != self._current_bb:
            self._previous_bb = self._current_bb
            self._current_bb = ip.bb()
        self._ip = ip

    # Returns the name of the basic block this lane was executing before
    # entering its current one.
    def get_previous_bb_name(self):
        return self._previous_bb.name()

    # Forwards a merge instruction encountered by this lane to its wave.
    def handle_convergence_header(self, instruction):
        self._wave.handle_convergence_header(self, instruction)

    # Call the function starting at ip. The value it returns will be stored
    # in output_register, and execution will resume at the instruction
    # following the current one.
    def do_call(self, ip, output_register):
        return_ip = None if self._ip is None else self._ip + 1
        self._callstack.append(
            (return_ip, lambda value: self.set_register(output_register, value))
        )
        self.set_ip(ip)

    # Return value from the current function: store it through the callback
    # registered by do_call, then resume the caller. Returning from the
    # outermost call stops the lane.
    def do_return(self, value):
        ip, callback = self._callstack.pop()

        callback(value)
        if len(self._callstack) == 0:
            self._running = False
        else:
            self.set_ip(ip)
+
+
# Represents the SPIR-V module in the simulator.
class Module:
    def __init__(self, instructions):
        chunks = splitInstructions(OpFunction, instructions)

        # The instructions located outside of all functions.
        self._prolog = chunks[0]
        # The functions in this module, indexed by the register defined by
        # their OpFunction.
        self._functions = {}
        # Every OpVariable and OpConstant (or subclass) instruction of the
        # module. The whole instruction list is scanned, not only the prolog.
        self._globals = [
            x
            for x in instructions
            if type(x) is OpVariable or isinstance(x, OpConstant)
        ]

        # Helper dictionaries to get real names of registers, or registers by names.
        self._name2reg = {}
        self._reg2name = {}
        for instruction in instructions:
            if type(instruction) is OpName:
                name = instruction.name()
                reg = instruction.decoratedRegister()
                self._name2reg[name] = reg
                self._reg2name[reg] = name

        for chunk in chunks[1:]:
            function = Function(chunk)
            assert function.name() not in self._functions
            self._functions[function.name()] = function

    # Returns the register matching "name" if any, None otherwise.
    # This assumes names are unique.
    def getRegisterFromName(self, name):
        return self._name2reg.get(name)

    # Returns the name given to "register" if any, None otherwise.
    def getNameFromRegister(self, register):
        return self._reg2name.get(register)

    # Initialize the module before wave execution begins.
    # See Instruction::static_execution for more details.
    def initialize(self, lane):
        for instruction in self._globals:
            instruction.static_execution(lane)

        # Initialize builtins
        for instruction in self._prolog:
            if type(instruction) is OpDecorate:
                instruction.static_execution(lane)

    # Executes the instruction ip points to on the given lane.
    def execute_one_instruction(self, lane: Lane, ip: InstructionPointer) -> None:
        ip.instruction().runtime_execution(self, lane)

    # Returns the first valid IP for the function defined by the given register.
    # Calling this with a register not returned by OpFunction is illegal and
    # raises RuntimeError.
    def get_function_entry(self, register: str) -> InstructionPointer:
        if register not in self._functions:
            raise RuntimeError(f"Function defining {register} not found.")
        return InstructionPointer(self._functions[register], 0, 0)

    # Returns the first valid IP for the basic block defined by register.
    # Calling this with a register not returned by an OpLabel is illegal and
    # raises RuntimeError.
    def get_bb_entry(self, register: str) -> InstructionPointer:
        for function in self._functions.values():
            index = function.get_bb_index(register)
            if index != -1:
                return InstructionPointer(function, index, 0)
        raise RuntimeError(f"Instruction defining {register} not found.")

    # Returns the list of function names in this module.
    # If an OpName exists for this function, returns the pretty name, else
    # returns the register name.
    def get_function_names(self):
        names = []
        for reg in self._functions:
            name = self.getNameFromRegister(reg)
            names.append(reg if name is None else name)
        return names

    # Returns the output registers of the global instructions (variables and
    # constants) tracked by this module.
    def variables(self) -> list:
        return [x.output_register() for x in self._globals]

    # Prints the globals, then either every function (function_name is None)
    # or only the function with the given name.
    def dump(self, function_name: str = None):
        print("Module:")
        print(" globals:")
        for instruction in self._globals:
            print(f" {instruction}")

        if function_name is None:
            print(" functions:")
            for register, function in self._functions.items():
                name = self.getNameFromRegister(register)
                print(f" Function {register} ({name})")
                function.dump()
            return

        register = self.getRegisterFromName(function_name)
        print(f" function {register} ({function_name}):")
        # The name may be unknown, or may designate a register which is not a
        # function: both cases are reported instead of raising.
        function = self._functions.get(register)
        if function is not None:
            function.dump()
        else:
            print(" error: cannot find function.")
+
+
# Defines a convergence requirement for the simulation:
# A list of lanes impacted by a merge and possibly the associated
# continue target.
@dataclass
class ConvergenceRequirement:
    # The IP of the first instruction of the merge block.
    mergeTarget: InstructionPointer
    # The IP of the first instruction of the continue block, or None when the
    # merge instruction has no continue target.
    continueTarget: InstructionPointer
    # The IDs of the lanes this requirement applies to.
    impactedLanes: set[int]
+
+
# Defines a Lane group/Wave in the simulator.
class Wave:
    def __init__(self, module, wave_size: int):
        assert wave_size > 0
        # The module this wave will execute.
        self._module = module
        # The lanes this wave will be composed of, indexed by lane ID.
        self._lanes = [Lane(self, tid) for tid in range(wave_size)]

        # The instructions scheduled for execution, with the lanes waiting to
        # execute each of them.
        self._tasks: dict[InstructionPointer, list[Lane]] = {}
        # The actual requirements to comply with when executing instructions.
        # e.g: the set of lanes required to merge before executing the merge block.
        self._convergence_requirements = []
        # The indices of the active lanes for the current executing instruction.
        self._active_lane_indices = set()

    # Returns True if the given IP can be executed for the given list of lanes.
    def _is_task_candidate(self, ip: InstructionPointer, lanes: list[Lane]):
        # The IDs of the lanes which cannot block this task: dead lanes, and
        # lanes waiting at the merge/continue target of another requirement.
        merged_lanes = {lane.tid() for lane in self._lanes if not lane.running()}

        for requirement in self._convergence_requirements:
            # This task is not executing a merge or continue target.
            # Adding all lanes at those points into the ignore list.
            if requirement.mergeTarget != ip and requirement.continueTarget != ip:
                for tid in requirement.impactedLanes:
                    if self._lanes[tid].ip() == requirement.mergeTarget:
                        merged_lanes.add(tid)
                    if self._lanes[tid].ip() == requirement.continueTarget:
                        merged_lanes.add(tid)
                continue

            # This task is executing the current requirement continue/merge
            # target.
            for tid in requirement.impactedLanes:
                lane = self._lanes[tid]
                if not lane.running():
                    continue

                if lane.tid() in merged_lanes:
                    continue

                if ip == requirement.mergeTarget:
                    # Every impacted lane must have reached the merge target.
                    if lane.ip() != requirement.mergeTarget:
                        return False
                else:
                    # Executing the continue target: every impacted lane must
                    # be at either the merge or the continue target.
                    if (
                        lane.ip() != requirement.mergeTarget
                        and lane.ip() != requirement.continueTarget
                    ):
                        return False
        return True

    # Returns the next task we can schedule, and removes it from the
    # scheduled tasks. This must always return a task.
    # Calling this when all lanes are dead is invalid and raises RuntimeError.
    def _get_next_runnable_task(self):
        for ip, lanes in self._tasks.items():
            if len(lanes) == 0:
                continue
            if self._is_task_candidate(ip, lanes):
                # Returning right away: the dictionary is not iterated again
                # after being modified.
                del self._tasks[ip]
                return (ip, lanes)
        raise RuntimeError("No task to execute. Deadlock?")

    # Handle an encountered merge instruction for the given lane.
    def handle_convergence_header(self, lane: Lane, instruction: Instruction):
        mergeTarget = self._module.get_bb_entry(instruction.merge_location())
        # A requirement already exists for this merge target: only record
        # that this lane is impacted too.
        for requirement in self._convergence_requirements:
            if requirement.mergeTarget == mergeTarget:
                requirement.impactedLanes.add(lane.tid())
                return

        continueTarget = None
        if instruction.continue_location():
            continueTarget = self._module.get_bb_entry(instruction.continue_location())
        requirement = ConvergenceRequirement(
            mergeTarget, continueTarget, {lane.tid()}
        )
        self._convergence_requirements.append(requirement)

    # Returns true if some instructions are scheduled for execution.
    def _has_tasks(self):
        return len(self._tasks) > 0

    # Returns the index of the first active lane right now.
    def get_first_active_lane_index(self) -> int:
        return min(self._active_lane_indices)

    # Broadcast the given value to the registers of all active lanes.
    def broadcast_register(self, register, value) -> None:
        for tid in self._active_lane_indices:
            self._lanes[tid].set_register(register, value)

    # Returns the entry IP of the function associated with 'name'.
    # Calling this function with an invalid name is illegal.
    def _get_function_from_name(self, name: str) -> InstructionPointer:
        register = self._module.getRegisterFromName(name)
        assert register is not None
        return self._module.get_function_entry(register)

    # Run the wave on the function 'function_name' until all lanes are dead.
    # If verbose is True, execution trace is printed.
    # Returns the value returned by the function for each lane.
    def run(self, function_name: str, verbose: bool = False) -> list[int]:
        for t in self._lanes:
            self._module.initialize(t)

        entry_ip = self._get_function_from_name(function_name)
        assert entry_ip is not None
        for t in self._lanes:
            t.do_call(entry_ip, "__shader_output__")

        # All lanes start on the same instruction. The task gets its own list
        # so the scheduler never shares a list object with self._lanes.
        self._tasks[self._lanes[0].ip()] = list(self._lanes)
        while self._has_tasks():
            ip, lanes = self._get_next_runnable_task()
            self._active_lane_indices = {x.tid() for x in lanes}
            if verbose:
                print(
                    f"Executing with lanes {self._active_lane_indices}: {ip.instruction()}"
                )

            for lane in lanes:
                self._module.execute_one_instruction(lane, ip)
                if not lane.running():
                    continue

                # Schedule the lane on the instruction it now points to.
                self._tasks.setdefault(lane.ip(), []).append(lane)

            if verbose and ip.instruction().has_output_register():
                register = ip.instruction().output_register()
                print(
                    f" {register:3} = {[ x.get_register(register, allow_undef=True) for x in lanes ]}"
                )

        output = []
        for lane in self._lanes:
            output.append(lane.get_register("__shader_output__"))
        return output

    # Prints the value of the given register for each lane of the wave.
    def dump_register(self, register):
        for lane in self._lanes:
            print(
                f" Lane {lane.tid():2} | {register:3} = {lane.get_register(register)}"
            )
+
+
# Command-line interface of the simulator.
parser = argparse.ArgumentParser(
    description="simulator", formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
    "-i", "--input", help="Text SPIR-V to read from", required=False, default="-"
)
parser.add_argument("-f", "--function", help="Function to execute")
# type=int: without it a wave size given on the command line is a string,
# while Wave compares its wave_size with 0 and passes it to range().
parser.add_argument(
    "-w", "--wave", help="Wave size", type=int, default=32, required=False
)
parser.add_argument(
    "-e",
    "--expects",
    help="Expected results per lanes, expects a list of values. Ex: '1, 2, 3'.",
)
parser.add_argument("-v", "--verbose", help="verbose", action="store_true")
# Arguments are parsed as soon as the script is loaded.
args = parser.parse_args()
+
+
+def load_instructions(filename):
+ if filename is None:
+ return []
+
+ if filename.lstrip().rstrip() != "-":
+ try:
+ with open(filename, "r") as f:
+ lines = f.read().split("\n")
+ except Exception: # (FileNotFoundError, PermissionError):
+ return []
+ else:
+ lines = sys.stdin.readlines()
+
+ # Remove leading/trailing whitespaces.
+ lines = [x.rstrip().lstrip() for x in lines]
----------------
michalpaszkowski wrote:
Could `.rstrip().lstrip()` be replaced with `.strip()`?
https://github.com/llvm/llvm-project/pull/104020
More information about the llvm-commits
mailing list