[Mlir-commits] [mlir] 7ff3d3b - [mlir][utils] Update generate-test-checks.py (use SSA names) (#136819)

llvmlistbot at llvm.org
Fri Apr 25 09:20:07 PDT 2025


Author: Andrzej Warzyński
Date: 2025-04-25T17:20:04+01:00
New Revision: 7ff3d3bd1d5b0c9096377bf4c89d0de043dec805

URL: https://github.com/llvm/llvm-project/commit/7ff3d3bd1d5b0c9096377bf4c89d0de043dec805
DIFF: https://github.com/llvm/llvm-project/commit/7ff3d3bd1d5b0c9096377bf4c89d0de043dec805.diff

LOG: [mlir][utils] Update generate-test-checks.py (use SSA names) (#136819)

This patch updates generate-test-checks.py to preserve original SSA
names (capitalized) when generating LIT variable names for function
arguments (i.e. for `CHECK-SAME` lines). This improves readability and
helps maintain consistency between the input MLIR and the expected
FileCheck/LIT output.

For example, given the following function:

```mlir
func.func @conv1d_nwc_4x2x8_memref(
    %input: memref<4x6x3xf32>,
    %filter: memref<1x3x8xf32>,
    %output: memref<4x2x8xf32>) {

  linalg.conv_1d_nwc_wcf
    {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
    ins(%input, %filter : memref<4x6x3xf32>, memref<1x3x8xf32>)
    outs(%output : memref<4x2x8xf32>)

  return
}
```

The generated output becomes:

```mlir
// CHECK-LABEL: func.func @conv1d_nwc_4x2x8_memref(
// CHECK-SAME:      %[[INPUT:.*]]: memref<4x6x3xf32>,
// CHECK-SAME:      %[[FILTER:.*]]: memref<1x3x8xf32>,
// CHECK-SAME:      %[[OUTPUT:.*]]: memref<4x2x8xf32>) {
// CHECK:         linalg.conv_1d_nwc_wcf
// CHECK:           {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
// CHECK:           ins(%[[INPUT]], %[[FILTER]] : memref<4x6x3xf32>, memref<1x3x8xf32>)
// CHECK:           outs(%[[OUTPUT]] : memref<4x2x8xf32>)
// CHECK:         return
// CHECK:       }
```

By contrast, the script before this patch generates:

```mlir
// CHECK-LABEL: func.func @conv1d_nwc_4x2x8_memref(
// CHECK-SAME:      %[[VAL_0:.*]]: memref<4x6x3xf32>,
// CHECK-SAME:      %[[VAL_1:.*]]: memref<1x3x8xf32>,
// CHECK-SAME:      %[[VAL_2:.*]]: memref<4x2x8xf32>) {
// CHECK:         linalg.conv_1d_nwc_wcf
// CHECK:           {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
// CHECK:           ins(%[[VAL_0]], %[[VAL_1]] : memref<4x6x3xf32>, memref<1x3x8xf32>)
// CHECK:           outs(%[[VAL_2]] : memref<4x2x8xf32>)
// CHECK:         return
// CHECK:       }
```
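The naming rule can be illustrated in isolation. The following is a simplified, standalone sketch, not the script's actual code: `SimpleNamer` is a hypothetical stand-in for the script's variable namer that models only the `VAL_<n>` counter and the `use_ssa_name` branch added by this patch (scopes and user-supplied `variable_names` are omitted). Because the patch passes `use_ssa_name=True` only when processing function arguments, values defined in the body keep `VAL_<n>` names.

```python
class SimpleNamer:
    """Hypothetical, simplified stand-in for the namer in generate-test-checks.py.

    Only the counter fallback and the `use_ssa_name` branch from this patch
    are modelled; scopes and user-supplied names are left out.
    """

    def __init__(self):
        self.name_counter = 0

    def generate_name(self, source_variable_name, use_ssa_name):
        # Reuse the (capitalized) SSA name when requested and when it starts
        # with a letter; numeric names such as `%0` fall back to VAL_<n>.
        if use_ssa_name and source_variable_name[0].isalpha():
            return source_variable_name.upper()
        variable_name = "VAL_" + str(self.name_counter)
        self.name_counter += 1
        return variable_name


# After the patch: argument names (CHECK-SAME lines) are preserved ...
namer = SimpleNamer()
args = [namer.generate_name(n, use_ssa_name=True) for n in ["input", "filter", "output"]]
# ... while a body result such as `%0` still gets a counter-based name.
body = namer.generate_name("0", use_ssa_name=False)
print(args, body)  # ['INPUT', 'FILTER', 'OUTPUT'] VAL_0

# Before the patch every name came from the counter.
old_namer = SimpleNamer()
old_names = [old_namer.generate_name(n, use_ssa_name=False) for n in ["input", "filter", "output"]]
print(old_names)  # ['VAL_0', 'VAL_1', 'VAL_2']
```

Note that the counter only advances on the fallback path, so preserved argument names do not consume `VAL_<n>` indices.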

Added: 
    

Modified: 
    mlir/utils/generate-test-checks.py

Removed: 
    


################################################################################
diff  --git a/mlir/utils/generate-test-checks.py b/mlir/utils/generate-test-checks.py
index d157af9c3cab7..11fb4e40072e7 100755
--- a/mlir/utils/generate-test-checks.py
+++ b/mlir/utils/generate-test-checks.py
@@ -77,13 +77,20 @@ def generate_in_parent_scope(self, n):
         self.generate_in_parent_scope_left = n
 
     # Generate a substitution name for the given ssa value name.
-    def generate_name(self, source_variable_name):
+    def generate_name(self, source_variable_name, use_ssa_name):
 
         # Compute variable name
         variable_name = self.variable_names.pop(0) if len(self.variable_names) > 0 else ''
         if variable_name == '':
-            variable_name = "VAL_" + str(self.name_counter)
-            self.name_counter += 1
+            # If `use_ssa_name` is set, use the MLIR SSA value name to generate
+            # a FileCheck substitution string. As FileCheck requires these
+            # strings to start with a letter, skip MLIR variables starting
+            # with a digit (e.g. `%0`).
+            if use_ssa_name and source_variable_name[0].isalpha():
+                variable_name = source_variable_name.upper()
+            else:
+                variable_name = "VAL_" + str(self.name_counter)
+                self.name_counter += 1
 
         # Scope where variable name is saved
         scope = len(self.scopes) - 1
@@ -158,7 +165,7 @@ def get_num_ssa_results(input_line):
 
 
 # Process a line of input that has been split at each SSA identifier '%'.
-def process_line(line_chunks, variable_namer, strict_name_re=False):
+def process_line(line_chunks, variable_namer, use_ssa_name=False, strict_name_re=False):
     output_line = ""
 
     # Process the rest that contained an SSA value name.
@@ -178,7 +185,7 @@ def process_line(line_chunks, variable_namer, strict_name_re=False):
             output_line += "%[[" + variable + "]]"
         else:
             # Otherwise, generate a new variable.
-            variable = variable_namer.generate_name(ssa_name)
+            variable = variable_namer.generate_name(ssa_name, use_ssa_name)
             if strict_name_re:
                 # Use stricter regexp for the variable name, if requested.
                 # Greedy matching may cause issues with the generic '.*'
@@ -415,9 +422,11 @@ def main():
                 pad_depth = label_length if label_length < 21 else 4
                 output_line += " " * pad_depth
 
-                # Process the rest of the line.
+                # Process the rest of the line. Use the original SSA name to generate the LIT
+                # variable names.
+                use_ssa_names = True
                 output_line += process_line(
-                    [argument], variable_namer, args.strict_name_re
+                    [argument], variable_namer, use_ssa_names, args.strict_name_re
                 )
 
         # Append the output line.
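The comment in `generate_name` states the FileCheck naming constraint informally, and the patch tests only the first character of the SSA name. The sketch below is not part of the patch; it compares that first-character test with a full-name check against the pattern FileCheck documents for variable names, `[a-zA-Z_][a-zA-Z0-9_]*`. `patch_accepts` and `is_valid_filecheck_name` are hypothetical helpers. Assuming the script's SSA regex captures the full MLIR identifier, including the punctuation MLIR allows in value names (`$`, `.`, `_`, `-`), a name such as `%in.0` would pass the first-character test but produce `IN.0`, which falls outside that pattern.

```python
import re

# Pattern for FileCheck variable names (as given in the FileCheck docs):
# a letter or underscore, then letters, digits or underscores.
FILECHECK_NAME_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")


def patch_accepts(ssa_name):
    # Mirrors the test used in this patch: only the first character is checked.
    return ssa_name[0].isalpha()


def is_valid_filecheck_name(name):
    # Checks the whole candidate name, not just its first character.
    return FILECHECK_NAME_RE.fullmatch(name) is not None


for ssa_name in ["input", "arg0", "0", "in.0", "a-b", "_x"]:
    print(f"%{ssa_name}: patch_accepts={patch_accepts(ssa_name)}, "
          f"valid={is_valid_filecheck_name(ssa_name.upper())}")
```

For `%input` and `%arg0` both checks agree; `%in.0` and `%a-b` pass the first-character test but fail the full check, while `%_x` is rejected by the patch even though `_X` would be a valid name (it simply falls back to `VAL_<n>`).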