[Mlir-commits] [mlir] dea01f5 - New features and bug fix in MLIR test generation tool

Eric Kunze llvmlistbot at llvm.org
Fri Jul 7 11:35:06 PDT 2023


Author: Rafael Ubal Tena
Date: 2023-07-07T18:15:11Z
New Revision: dea01f5e00e45dec4319475a001024c6ee882283

URL: https://github.com/llvm/llvm-project/commit/dea01f5e00e45dec4319475a001024c6ee882283
DIFF: https://github.com/llvm/llvm-project/commit/dea01f5e00e45dec4319475a001024c6ee882283.diff

LOG: New features and bug fix in MLIR test generation tool

- Option `--variable_names <names>` allows the user to pass names for FileCheck
  regexps representing variables. Variable names are separated by commas, and
  an empty entry makes the tool generate a default name for that variable.
  For example, `--variable_names arg_0,arg_1,,,result` will produce regexp names
  `ARG_0`, `ARG_1`, `VAL_0`, `VAL_1`, `RESULT`, `VAL_2`, `VAL_3`, ...

- Option `--attribute_names <names>` can be used to generate global regexp names
  to represent attributes. Useful for affine maps. It takes the same
  comma-separated format as `--variable_names`; empty entries get the default
  names `ATTR_0`, `ATTR_1`, ...

- Fixed a bug in scope detection of SSA variables in ops with nested regions that
  return SSA values (e.g., 'linalg.generic'). Previously, the returned SSA values
  were inserted in the nested scope instead of the parent scope.

This version of the tool has been used to generate unit tests for the following
patch: https://reviews.llvm.org/D153291

For example, the main body of the test named 'test_select_2d_one_dynamic' was
generated using the following command:

```
$ mlir-opt -pass-pipeline='builtin.module(func.func(tosa-to-linalg))' test_select_2d_one_dynamic.tosa.mlir | generate-test-checks.py --attribute_names map0,map1,map2 --variable_names arg0,arg1,arg2,const1,arg0_dim1,arg1_dim1,,arg2_dim1,max_dim1,,,arg0_broadcast,,,,,,,arg1_broadcast,,,,,,,arg2_broadcast,,,,,,result
```

Reviewed By: eric-k256

Differential Revision: https://reviews.llvm.org/D154458

Added: 
    

Modified: 
    mlir/utils/generate-test-checks.py

Removed: 
    


################################################################################
diff  --git a/mlir/utils/generate-test-checks.py b/mlir/utils/generate-test-checks.py
index 0210d7a56ebf51..2f3293952af637 100755
--- a/mlir/utils/generate-test-checks.py
+++ b/mlir/utils/generate-test-checks.py
@@ -45,20 +45,60 @@
 SSA_RE_STR = "[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*"
 SSA_RE = re.compile(SSA_RE_STR)
 
+# Regex matching the left-hand side of an assignment
+SSA_RESULTS_STR = r'\s*(%' + SSA_RE_STR + r')(\s*,\s*(%' + SSA_RE_STR + r'))*\s*='
+SSA_RESULTS_RE = re.compile(SSA_RESULTS_STR)
+
+# Regex matching attributes
+ATTR_RE_STR = r'(#[a-zA-Z._-][a-zA-Z0-9._-]*)'
+ATTR_RE = re.compile(ATTR_RE_STR)
+
+# Regex matching the left-hand side of an attribute definition
+ATTR_DEF_RE_STR = r'\s*' + ATTR_RE_STR + r'\s*='
+ATTR_DEF_RE = re.compile(ATTR_DEF_RE_STR)
+
 
 # Class used to generate and manage string substitution blocks for SSA value
 # names.
-class SSAVariableNamer:
-    def __init__(self):
+class VariableNamer:
+    def __init__(self, variable_names):
         self.scopes = []
         self.name_counter = 0
 
+        # Number of variable names to still generate in parent scope
+        self.generate_in_parent_scope_left = 0
+
+        # Parse variable names
+        self.variable_names = [name.upper() for name in variable_names.split(',')]
+        self.used_variable_names = set()
+
+    # Generate the following 'n' variable names in the parent scope. 
+    def generate_in_parent_scope(self, n):
+        self.generate_in_parent_scope_left = n
+
     # Generate a substitution name for the given ssa value name.
-    def generate_name(self, ssa_name):
-        variable = "VAL_" + str(self.name_counter)
-        self.name_counter += 1
-        self.scopes[-1][ssa_name] = variable
-        return variable
+    def generate_name(self, source_variable_name):
+
+        # Compute variable name
+        variable_name = self.variable_names.pop(0) if len(self.variable_names) > 0 else ''
+        if variable_name == '':
+            variable_name = "VAL_" + str(self.name_counter)
+            self.name_counter += 1
+
+        # Scope where variable name is saved
+        scope = len(self.scopes) - 1
+        if self.generate_in_parent_scope_left > 0:
+            self.generate_in_parent_scope_left -= 1
+            scope = len(self.scopes) - 2
+        assert(scope >= 0)
+
+        # Save variable
+        if variable_name in self.used_variable_names:
+            raise RuntimeError(variable_name + ': duplicate variable name')
+        self.scopes[scope][source_variable_name] = variable_name
+        self.used_variable_names.add(variable_name)
+
+        return variable_name
 
     # Push a new variable name scope.
     def push_name_scope(self):
@@ -76,6 +116,46 @@ def num_scopes(self):
     def clear_counter(self):
         self.name_counter = 0
 
+class AttributeNamer:
+
+    def __init__(self, attribute_names):
+        self.name_counter = 0
+        self.attribute_names = [name.upper() for name in attribute_names.split(',')]
+        self.map = {}
+        self.used_attribute_names = set()
+    
+    # Generate a substitution name for the given attribute name.
+    def generate_name(self, source_attribute_name):
+
+        # Compute FileCheck name
+        attribute_name = self.attribute_names.pop(0) if len(self.attribute_names) > 0 else ''
+        if attribute_name == '':
+            attribute_name = "ATTR_" + str(self.name_counter)
+            self.name_counter += 1
+
+        # Prepend global symbol
+        attribute_name = '$' + attribute_name
+
+        # Save attribute
+        if attribute_name in self.used_attribute_names:
+            raise RuntimeError(attribute_name + ': duplicate attribute name')
+        self.map[source_attribute_name] = attribute_name
+        self.used_attribute_names.add(attribute_name)
+        return attribute_name
+
+    # Get the saved substitution name for the given attribute name. If no name
+    # has been generated for the given attribute yet, the source attribute name
+    # itself is returned.
+    def get_name(self, source_attribute_name):
+        return self.map[source_attribute_name] if source_attribute_name in self.map else '?'
+
+# Return the number of SSA results in a line of type
+#   %0, %1, ... = ...
+# The function returns 0 if there are no results.
+def get_num_ssa_results(input_line):
+    m = SSA_RESULTS_RE.match(input_line)
+    return m.group().count('%') if m else 0
+
 
 # Process a line of input that has been split at each SSA identifier '%'.
 def process_line(line_chunks, variable_namer):
@@ -84,7 +164,7 @@ def process_line(line_chunks, variable_namer):
     # Process the rest that contained an SSA value name.
     for chunk in line_chunks:
         m = SSA_RE.match(chunk)
-        ssa_name = m.group(0)
+        ssa_name = m.group(0) if m is not None else ''
 
         # Check if an existing variable exists for this name.
         variable = None
@@ -126,6 +206,25 @@ def process_source_lines(source_lines, note, args):
         source_segments[-1].append(line + "\n")
     return source_segments
 
+def process_attribute_definition(line, attribute_namer, output):
+    m = ATTR_DEF_RE.match(line)
+    if m:
+        attribute_name = attribute_namer.generate_name(m.group(1))
+        line = '// CHECK: #[[' + attribute_name + ':.+]] =' + line[len(m.group(0)):] + '\n'
+        output.write(line)
+
+def process_attribute_references(line, attribute_namer):
+
+    output_line = ''
+    components = ATTR_RE.split(line)
+    for component in components:
+        m = ATTR_RE.match(component)
+        if m:
+            output_line += '#[[' + attribute_namer.get_name(m.group(1)) + ']]'
+            output_line += component[len(m.group()):]
+        else:
+            output_line += component
+    return output_line
 
 # Pre-process a line of input to remove any character sequences that will be
 # problematic with FileCheck.
@@ -171,6 +270,20 @@ def main():
         'it omits "module {"',
     )
     parser.add_argument("-i", "--inplace", action="store_true", default=False)
+    parser.add_argument(
+        "--variable_names",
+        type=str,
+        default='',
+        help="Names to be used in FileCheck regular expression to represent SSA "
+        "variables in the order they are encountered. Separate names with commas, "
+        "and leave empty entries for default names (e.g.: 'DIM,,SUM,RESULT')")
+    parser.add_argument(
+        "--attribute_names",
+        type=str,
+        default='',
+        help="Names to be used in FileCheck regular expression to represent "
+        "attributes in the order they are defined. Separate names with commas,"
+        "commas, and leave empty entries for default names (e.g.: 'MAP0,,,MAP1')")
 
     args = parser.parse_args()
 
@@ -197,15 +310,22 @@ def main():
         output = args.output
 
     output_segments = [[]]
-    # A map containing data used for naming SSA value names.
-    variable_namer = SSAVariableNamer()
+
+    # Namers
+    variable_namer = VariableNamer(args.variable_names)
+    attribute_namer = AttributeNamer(args.attribute_names)
+
+    # Process lines
     for input_line in input_lines:
         if not input_line:
             continue
-        lstripped_input_line = input_line.lstrip()
+
+        # Check if this is an attribute definition and process it
+        process_attribute_definition(input_line, attribute_namer, output)
 
         # Lines with blocks begin with a ^. These lines have a trailing comment
         # that needs to be stripped.
+        lstripped_input_line = input_line.lstrip()
         is_block = lstripped_input_line[0] == "^"
         if is_block:
             input_line = input_line.rsplit("//", 1)[0].rstrip()
@@ -222,6 +342,10 @@ def main():
             variable_namer.push_name_scope()
             if cur_level == args.starts_from_scope:
                 output_segments.append([])
+           
+            # Result SSA values must still be pushed to parent scope
+            num_ssa_results = get_num_ssa_results(input_line)
+            variable_namer.generate_in_parent_scope(num_ssa_results)
 
         # Omit lines at the near top level e.g. "module {".
         if cur_level < args.starts_from_scope:
@@ -234,6 +358,9 @@ def main():
         # FileCheck.
         input_line = preprocess_line(input_line)
 
+        # Process uses of attributes in this line
+        input_line = process_attribute_references(input_line, attribute_namer)
+
         # Split the line at the each SSA value name.
         ssa_split = input_line.split("%")
 


        


More information about the Mlir-commits mailing list