[Mlir-commits] [mlir] [mlir][spirv] Remove code for de-duplicating symbols in SPIR-V grammar (PR #111778)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Oct 9 18:01:54 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Caio Oliveira (cmarcelo)
<details>
<summary>Changes</summary>
The SPIR-V grammar was updated upstream to use an "aliases" field instead of duplicating symbols that share the same value. See https://github.com/KhronosGroup/SPIRV-Headers/pull/447 for details.
---
Full diff: https://github.com/llvm/llvm-project/pull/111778.diff
1 File Affected:
- (modified) mlir/utils/spirv/gen_spirv_dialect.py (+10-91)
``````````diff
diff --git a/mlir/utils/spirv/gen_spirv_dialect.py b/mlir/utils/spirv/gen_spirv_dialect.py
index 78c1022428d8a1..917bf08a71f1ce 100755
--- a/mlir/utils/spirv/gen_spirv_dialect.py
+++ b/mlir/utils/spirv/gen_spirv_dialect.py
@@ -127,44 +127,6 @@ def split_list_into_sublists(items):
return chuncks
-def uniquify_enum_cases(lst):
- """Prunes duplicate enum cases from the list.
-
- Arguments:
- - lst: List whose elements are to be uniqued. Assumes each element is a
- (symbol, value) pair and elements already sorted according to value.
-
- Returns:
- - A list with all duplicates removed. The elements are sorted according to
- value and, for each value, uniqued according to symbol.
- original list,
- - A map from deduplicated cases to the uniqued case.
- """
- cases = lst
- uniqued_cases = []
- duplicated_cases = {}
-
- # First sort according to the value
- cases.sort(key=lambda x: x[1])
-
- # Then group them according to the value
- for _, groups in itertools.groupby(cases, key=lambda x: x[1]):
- # For each value, sort according to the enumerant symbol.
- sorted_group = sorted(groups, key=lambda x: x[0])
- # Keep the "smallest" case, which is typically the symbol without extension
- # suffix. But we have special cases that we want to fix.
- case = sorted_group[0]
- for i in range(1, len(sorted_group)):
- duplicated_cases[sorted_group[i][0]] = case[0]
- if case[0] == "HlslSemanticGOOGLE":
- assert len(sorted_group) == 2, "unexpected new variant for HlslSemantic"
- case = sorted_group[1]
- duplicated_cases[sorted_group[0][0]] = case[0]
- uniqued_cases.append(case)
-
- return uniqued_cases, duplicated_cases
-
-
def toposort(dag, sort_fn):
"""Topologically sorts the given dag.
@@ -197,14 +159,12 @@ def get_next_batch(dag):
return sorted_nodes
-def toposort_capabilities(all_cases, capability_mapping):
+def toposort_capabilities(all_cases):
"""Returns topologically sorted capability (symbol, value) pairs.
Arguments:
- all_cases: all capability cases (containing symbol, value, and implied
capabilities).
- - capability_mapping: mapping from duplicated capability symbols to the
- canonicalized symbol chosen for SPIRVBase.td.
Returns:
A list containing topologically sorted capability (symbol, value) pairs.
@@ -215,13 +175,10 @@ def toposort_capabilities(all_cases, capability_mapping):
# Get the current capability.
cur = case["enumerant"]
name_to_value[cur] = case["value"]
- # Ignore duplicated symbols.
- if cur in capability_mapping:
- continue
# Get capabilities implied by the current capability.
prev = case.get("capabilities", [])
- uniqued_prev = set([capability_mapping.get(c, c) for c in prev])
+ uniqued_prev = set(prev)
dag[cur] = uniqued_prev
sorted_caps = toposort(dag, lambda x: name_to_value[x])
@@ -229,36 +186,12 @@ def toposort_capabilities(all_cases, capability_mapping):
return [(c, name_to_value[c]) for c in sorted_caps]
-def get_capability_mapping(operand_kinds):
- """Returns the capability mapping from duplicated cases to canonicalized ones.
-
- Arguments:
- - operand_kinds: all operand kinds' grammar spec
-
- Returns:
- - A map mapping from duplicated capability symbols to the canonicalized
- symbol chosen for SPIRVBase.td.
- """
- # Find the operand kind for capability
- cap_kind = {}
- for kind in operand_kinds:
- if kind["kind"] == "Capability":
- cap_kind = kind
-
- kind_cases = [(case["enumerant"], case["value"]) for case in cap_kind["enumerants"]]
- _, capability_mapping = uniquify_enum_cases(kind_cases)
-
- return capability_mapping
-
-
-def get_availability_spec(enum_case, capability_mapping, for_op, for_cap):
+def get_availability_spec(enum_case, for_op, for_cap):
"""Returns the availability specification string for the given enum case.
Arguments:
- enum_case: the enum case to generate availability spec for. It may contain
'version', 'lastVersion', 'extensions', or 'capabilities'.
- - capability_mapping: mapping from duplicated capability symbols to the
- canonicalized symbol chosen for SPIRVBase.td.
- for_op: bool value indicating whether this is the availability spec for an
op itself.
- for_cap: bool value indicating whether this is the availability spec for
@@ -313,10 +246,7 @@ def get_availability_spec(enum_case, capability_mapping, for_op, for_cap):
if caps:
canonicalized_caps = []
for c in caps:
- if c in capability_mapping:
- canonicalized_caps.append(capability_mapping[c])
- else:
- canonicalized_caps.append(c)
+ canonicalized_caps.append(c)
prefixed_caps = [
"SPIRV_C_{}".format(c) for c in sorted(set(canonicalized_caps))
]
@@ -357,7 +287,7 @@ def get_availability_spec(enum_case, capability_mapping, for_op, for_cap):
return "{}{}{}".format(implies, "\n " if implies and avail else "", avail)
-def gen_operand_kind_enum_attr(operand_kind, capability_mapping):
+def gen_operand_kind_enum_attr(operand_kind):
"""Generates the TableGen EnumAttr definition for the given operand kind.
Returns:
@@ -388,13 +318,12 @@ def get_case_symbol(kind_name, case_name):
# Special treatment for capability cases: we need to sort them topologically
# because a capability can refer to another via the 'implies' field.
kind_cases = toposort_capabilities(
- operand_kind["enumerants"], capability_mapping
+ operand_kind["enumerants"]
)
else:
kind_cases = [
(case["enumerant"], case["value"]) for case in operand_kind["enumerants"]
]
- kind_cases, _ = uniquify_enum_cases(kind_cases)
max_len = max([len(symbol) for (symbol, _) in kind_cases])
# Generate the definition for each enum case
@@ -412,7 +341,6 @@ def get_case_symbol(kind_name, case_name):
value = int(case_pair[1])
avail = get_availability_spec(
name_to_case_dict[name],
- capability_mapping,
False,
kind_name == "Capability",
)
@@ -648,11 +576,9 @@ def update_td_enum_attrs(path, operand_kinds, filter_list):
]
filter_list.extend(existing_kinds)
- capability_mapping = get_capability_mapping(operand_kinds)
-
# Generate definitions for all enums in filter list
defs = [
- gen_operand_kind_enum_attr(kind, capability_mapping)
+ gen_operand_kind_enum_attr(kind)
for kind in operand_kinds
if kind["kind"] in filter_list
]
@@ -762,7 +688,7 @@ def get_description(text, appendix):
def get_op_definition(
- instruction, opname, doc, existing_info, capability_mapping, settings
+ instruction, opname, doc, existing_info, settings
):
"""Generates the TableGen op definition for the given SPIR-V instruction.
@@ -771,8 +697,6 @@ def get_op_definition(
- doc: the instruction's SPIR-V HTML doc
- existing_info: a dict containing potential manually specified sections for
this instruction
- - capability_mapping: mapping from duplicated capability symbols to the
- canonicalized symbol chosen for SPIRVBase.td
Returns:
- A string containing the TableGen op definition
@@ -840,7 +764,7 @@ def get_op_definition(
operands = instruction.get("operands", [])
# Op availability
- avail = get_availability_spec(instruction, capability_mapping, True, False)
+ avail = get_availability_spec(instruction, True, False)
if avail:
avail = "\n\n {0}".format(avail)
@@ -1019,7 +943,7 @@ def extract_td_op_info(op_def):
def update_td_op_definitions(
- path, instructions, docs, filter_list, inst_category, capability_mapping, settings
+ path, instructions, docs, filter_list, inst_category, settings
):
"""Updates SPIRVOps.td with newly generated op definition.
@@ -1028,8 +952,6 @@ def update_td_op_definitions(
- instructions: SPIR-V JSON grammar for all instructions
- docs: SPIR-V HTML doc for all instructions
- filter_list: a list containing new opnames to include
- - capability_mapping: mapping from duplicated capability symbols to the
- canonicalized symbol chosen for SPIRVBase.td.
Returns:
- A string containing all the TableGen op definitions
@@ -1077,7 +999,6 @@ def update_td_op_definitions(
opname,
docs[fixed_opname],
op_info_dict.get(opname, {"inst_category": inst_category}),
- capability_mapping,
settings,
)
)
@@ -1184,14 +1105,12 @@ def update_td_op_definitions(
if args.new_inst is not None:
assert args.op_td_path is not None
docs = get_spirv_doc_from_html_spec(ext_html_url, args)
- capability_mapping = get_capability_mapping(operand_kinds)
update_td_op_definitions(
args.op_td_path,
instructions,
docs,
args.new_inst,
args.inst_category,
- capability_mapping,
args,
)
print("Done. Note that this script just generates a template; ", end="")
``````````
</details>
https://github.com/llvm/llvm-project/pull/111778
More information about the Mlir-commits mailing list