[Mlir-commits] [mlir] [mlir][spirv] Remove code for de-duplicating symbols in SPIR-V grammar (PR #111778)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 9 18:01:54 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Caio Oliveira (cmarcelo)

<details>
<summary>Changes</summary>

The upstream SPIR-V grammar now records alternative names for an enumerant in an "aliases" field instead of duplicating symbols that share the same value, so the local de-duplication logic is no longer needed. See https://github.com/KhronosGroup/SPIRV-Headers/pull/447 for details.

---
Full diff: https://github.com/llvm/llvm-project/pull/111778.diff


1 Files Affected:

- (modified) mlir/utils/spirv/gen_spirv_dialect.py (+10-91) 


``````````diff
diff --git a/mlir/utils/spirv/gen_spirv_dialect.py b/mlir/utils/spirv/gen_spirv_dialect.py
index 78c1022428d8a1..917bf08a71f1ce 100755
--- a/mlir/utils/spirv/gen_spirv_dialect.py
+++ b/mlir/utils/spirv/gen_spirv_dialect.py
@@ -127,44 +127,6 @@ def split_list_into_sublists(items):
     return chuncks
 
 
-def uniquify_enum_cases(lst):
-    """Prunes duplicate enum cases from the list.
-
-    Arguments:
-     - lst: List whose elements are to be uniqued. Assumes each element is a
-       (symbol, value) pair and elements already sorted according to value.
-
-    Returns:
-     - A list with all duplicates removed. The elements are sorted according to
-       value and, for each value, uniqued according to symbol.
-       original list,
-     - A map from deduplicated cases to the uniqued case.
-    """
-    cases = lst
-    uniqued_cases = []
-    duplicated_cases = {}
-
-    # First sort according to the value
-    cases.sort(key=lambda x: x[1])
-
-    # Then group them according to the value
-    for _, groups in itertools.groupby(cases, key=lambda x: x[1]):
-        # For each value, sort according to the enumerant symbol.
-        sorted_group = sorted(groups, key=lambda x: x[0])
-        # Keep the "smallest" case, which is typically the symbol without extension
-        # suffix. But we have special cases that we want to fix.
-        case = sorted_group[0]
-        for i in range(1, len(sorted_group)):
-            duplicated_cases[sorted_group[i][0]] = case[0]
-        if case[0] == "HlslSemanticGOOGLE":
-            assert len(sorted_group) == 2, "unexpected new variant for HlslSemantic"
-            case = sorted_group[1]
-            duplicated_cases[sorted_group[0][0]] = case[0]
-        uniqued_cases.append(case)
-
-    return uniqued_cases, duplicated_cases
-
-
 def toposort(dag, sort_fn):
     """Topologically sorts the given dag.
 
@@ -197,14 +159,12 @@ def get_next_batch(dag):
     return sorted_nodes
 
 
-def toposort_capabilities(all_cases, capability_mapping):
+def toposort_capabilities(all_cases):
     """Returns topologically sorted capability (symbol, value) pairs.
 
     Arguments:
       - all_cases: all capability cases (containing symbol, value, and implied
         capabilities).
-      - capability_mapping: mapping from duplicated capability symbols to the
-        canonicalized symbol chosen for SPIRVBase.td.
 
     Returns:
       A list containing topologically sorted capability (symbol, value) pairs.
@@ -215,13 +175,10 @@ def toposort_capabilities(all_cases, capability_mapping):
         # Get the current capability.
         cur = case["enumerant"]
         name_to_value[cur] = case["value"]
-        # Ignore duplicated symbols.
-        if cur in capability_mapping:
-            continue
 
         # Get capabilities implied by the current capability.
         prev = case.get("capabilities", [])
-        uniqued_prev = set([capability_mapping.get(c, c) for c in prev])
+        uniqued_prev = set(prev)
         dag[cur] = uniqued_prev
 
     sorted_caps = toposort(dag, lambda x: name_to_value[x])
@@ -229,36 +186,12 @@ def toposort_capabilities(all_cases, capability_mapping):
     return [(c, name_to_value[c]) for c in sorted_caps]
 
 
-def get_capability_mapping(operand_kinds):
-    """Returns the capability mapping from duplicated cases to canonicalized ones.
-
-    Arguments:
-      - operand_kinds: all operand kinds' grammar spec
-
-    Returns:
-      - A map mapping from duplicated capability symbols to the canonicalized
-        symbol chosen for SPIRVBase.td.
-    """
-    # Find the operand kind for capability
-    cap_kind = {}
-    for kind in operand_kinds:
-        if kind["kind"] == "Capability":
-            cap_kind = kind
-
-    kind_cases = [(case["enumerant"], case["value"]) for case in cap_kind["enumerants"]]
-    _, capability_mapping = uniquify_enum_cases(kind_cases)
-
-    return capability_mapping
-
-
-def get_availability_spec(enum_case, capability_mapping, for_op, for_cap):
+def get_availability_spec(enum_case, for_op, for_cap):
     """Returns the availability specification string for the given enum case.
 
     Arguments:
       - enum_case: the enum case to generate availability spec for. It may contain
         'version', 'lastVersion', 'extensions', or 'capabilities'.
-      - capability_mapping: mapping from duplicated capability symbols to the
-        canonicalized symbol chosen for SPIRVBase.td.
       - for_op: bool value indicating whether this is the availability spec for an
         op itself.
       - for_cap: bool value indicating whether this is the availability spec for
@@ -313,10 +246,7 @@ def get_availability_spec(enum_case, capability_mapping, for_op, for_cap):
     if caps:
         canonicalized_caps = []
         for c in caps:
-            if c in capability_mapping:
-                canonicalized_caps.append(capability_mapping[c])
-            else:
-                canonicalized_caps.append(c)
+            canonicalized_caps.append(c)
         prefixed_caps = [
             "SPIRV_C_{}".format(c) for c in sorted(set(canonicalized_caps))
         ]
@@ -357,7 +287,7 @@ def get_availability_spec(enum_case, capability_mapping, for_op, for_cap):
     return "{}{}{}".format(implies, "\n  " if implies and avail else "", avail)
 
 
-def gen_operand_kind_enum_attr(operand_kind, capability_mapping):
+def gen_operand_kind_enum_attr(operand_kind):
     """Generates the TableGen EnumAttr definition for the given operand kind.
 
     Returns:
@@ -388,13 +318,12 @@ def get_case_symbol(kind_name, case_name):
         # Special treatment for capability cases: we need to sort them topologically
         # because a capability can refer to another via the 'implies' field.
         kind_cases = toposort_capabilities(
-            operand_kind["enumerants"], capability_mapping
+            operand_kind["enumerants"]
         )
     else:
         kind_cases = [
             (case["enumerant"], case["value"]) for case in operand_kind["enumerants"]
         ]
-        kind_cases, _ = uniquify_enum_cases(kind_cases)
     max_len = max([len(symbol) for (symbol, _) in kind_cases])
 
     # Generate the definition for each enum case
@@ -412,7 +341,6 @@ def get_case_symbol(kind_name, case_name):
             value = int(case_pair[1])
         avail = get_availability_spec(
             name_to_case_dict[name],
-            capability_mapping,
             False,
             kind_name == "Capability",
         )
@@ -648,11 +576,9 @@ def update_td_enum_attrs(path, operand_kinds, filter_list):
     ]
     filter_list.extend(existing_kinds)
 
-    capability_mapping = get_capability_mapping(operand_kinds)
-
     # Generate definitions for all enums in filter list
     defs = [
-        gen_operand_kind_enum_attr(kind, capability_mapping)
+        gen_operand_kind_enum_attr(kind)
         for kind in operand_kinds
         if kind["kind"] in filter_list
     ]
@@ -762,7 +688,7 @@ def get_description(text, appendix):
 
 
 def get_op_definition(
-    instruction, opname, doc, existing_info, capability_mapping, settings
+    instruction, opname, doc, existing_info, settings
 ):
     """Generates the TableGen op definition for the given SPIR-V instruction.
 
@@ -771,8 +697,6 @@ def get_op_definition(
       - doc: the instruction's SPIR-V HTML doc
       - existing_info: a dict containing potential manually specified sections for
         this instruction
-      - capability_mapping: mapping from duplicated capability symbols to the
-                     canonicalized symbol chosen for SPIRVBase.td
 
     Returns:
       - A string containing the TableGen op definition
@@ -840,7 +764,7 @@ def get_op_definition(
     operands = instruction.get("operands", [])
 
     # Op availability
-    avail = get_availability_spec(instruction, capability_mapping, True, False)
+    avail = get_availability_spec(instruction, True, False)
     if avail:
         avail = "\n\n  {0}".format(avail)
 
@@ -1019,7 +943,7 @@ def extract_td_op_info(op_def):
 
 
 def update_td_op_definitions(
-    path, instructions, docs, filter_list, inst_category, capability_mapping, settings
+    path, instructions, docs, filter_list, inst_category, settings
 ):
     """Updates SPIRVOps.td with newly generated op definition.
 
@@ -1028,8 +952,6 @@ def update_td_op_definitions(
       - instructions: SPIR-V JSON grammar for all instructions
       - docs: SPIR-V HTML doc for all instructions
       - filter_list: a list containing new opnames to include
-      - capability_mapping: mapping from duplicated capability symbols to the
-                     canonicalized symbol chosen for SPIRVBase.td.
 
     Returns:
       - A string containing all the TableGen op definitions
@@ -1077,7 +999,6 @@ def update_td_op_definitions(
                     opname,
                     docs[fixed_opname],
                     op_info_dict.get(opname, {"inst_category": inst_category}),
-                    capability_mapping,
                     settings,
                 )
             )
@@ -1184,14 +1105,12 @@ def update_td_op_definitions(
     if args.new_inst is not None:
         assert args.op_td_path is not None
         docs = get_spirv_doc_from_html_spec(ext_html_url, args)
-        capability_mapping = get_capability_mapping(operand_kinds)
         update_td_op_definitions(
             args.op_td_path,
             instructions,
             docs,
             args.new_inst,
             args.inst_category,
-            capability_mapping,
             args,
         )
         print("Done. Note that this script just generates a template; ", end="")

``````````

</details>


https://github.com/llvm/llvm-project/pull/111778


More information about the Mlir-commits mailing list