[Mlir-commits] [mlir] [mlir][utils] Update generate-test-checks.py (use SSA names) (PR #136819)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 23 00:07:40 PDT 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

<details>
<summary>Changes</summary>

This patch updates generate-test-checks.py to preserve the original SSA
names (converted to upper case) when generating LIT (FileCheck) variable
names for function arguments (i.e. in `CHECK-SAME` lines). This improves
readability and keeps the expected FileCheck output consistent with the
input MLIR.
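
As a rough illustration of the naming rule described above, the sketch below maps a list of argument names to FileCheck variable names under both the previous and the patched scheme. The helper `argument_variable_names` is hypothetical: it models only the naming step, not the script's scope tracking or the rest of its line processing.

```python
def argument_variable_names(ssa_names, use_ssa_name):
    """Map argument SSA names (without the leading '%') to FileCheck variable names."""
    names = []
    counter = 0
    for ssa_name in ssa_names:
        if use_ssa_name:
            # Patched behaviour: reuse the SSA name, upper-cased.
            names.append(ssa_name.upper())
        else:
            # Previous behaviour: number values in order of appearance.
            names.append("VAL_" + str(counter))
            counter += 1
    return names


args = ["input", "filter", "output"]
print(argument_variable_names(args, use_ssa_name=True))   # ['INPUT', 'FILTER', 'OUTPUT']
print(argument_variable_names(args, use_ssa_name=False))  # ['VAL_0', 'VAL_1', 'VAL_2']
```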

For example, given the following function:

```mlir
func.func @example(
    %input: memref<4x6x3xf32>,
    %filter: memref<1x3x8xf32>,
    %output: memref<4x2x8xf32>) {

  linalg.conv_1d_nwc_wcf
    {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
    ins(%input, %filter : memref<4x6x3xf32>, memref<1x3x8xf32>)
    outs(%output : memref<4x2x8xf32>)

  return
}
```

The generated output becomes:

```mlir
// CHECK-LABEL: func.func @example(
// CHECK-SAME:      %[[INPUT:.*]]: memref<4x6x3xf32>,
// CHECK-SAME:      %[[FILTER:.*]]: memref<1x3x8xf32>,
// CHECK-SAME:      %[[OUTPUT:.*]]: memref<4x2x8xf32>) {
// CHECK:         linalg.conv_1d_nwc_wcf
// CHECK:           {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
// CHECK:           ins(%[[INPUT]], %[[FILTER]] : memref<4x6x3xf32>, memref<1x3x8xf32>)
// CHECK:           outs(%[[OUTPUT]] : memref<4x2x8xf32>)
// CHECK:         return
// CHECK:       }
```
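
How argument names flow from the `CHECK-SAME` lines into later uses can be sketched as follows: the first occurrence of an SSA value defines a FileCheck variable (`%[[INPUT:.*]]`) and later occurrences reference it (`%[[INPUT]]`). The function `to_check_line` and the regex `SSA_NAME_RE` are hypothetical simplifications, not the script's code; for brevity the sketch upper-cases every name it meets, whereas the patch applies the SSA-name scheme only to function arguments.

```python
import re

# Simplified SSA-name pattern (assumption: the script's own SSA_RE may differ).
SSA_NAME_RE = re.compile(r"%([A-Za-z0-9_$.-]+)")


def to_check_line(line, defined):
    """Rewrite %name occurrences as FileCheck variables.

    The first occurrence defines a variable (%[[NAME:.*]]); later occurrences
    reuse it (%[[NAME]]). `defined` maps SSA names to FileCheck names.
    """

    def replace(match):
        ssa_name = match.group(1)
        if ssa_name in defined:
            return "%[[" + defined[ssa_name] + "]]"
        variable = ssa_name.upper()
        defined[ssa_name] = variable
        return "%[[" + variable + ":.*]]"

    return SSA_NAME_RE.sub(replace, line)


defined = {}
for arg in ["%input: memref<4x6x3xf32>,", "%filter: memref<1x3x8xf32>,"]:
    print("// CHECK-SAME:      " + to_check_line(arg, defined))
print("// CHECK:           " + to_check_line(
    "ins(%input, %filter : memref<4x6x3xf32>, memref<1x3x8xf32>)", defined))
```

Run as-is, this reproduces the `INPUT`/`FILTER` definitions and the `ins(%[[INPUT]], %[[FILTER]] : ...)` use shown in the expected output above.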

By contrast, the current version of the script would generate:

```mlir
// CHECK-LABEL: func.func @example(
// CHECK-SAME:      %[[VAL_0:.*]]: memref<4x6x3xf32>,
// CHECK-SAME:      %[[VAL_1:.*]]: memref<1x3x8xf32>,
// CHECK-SAME:      %[[VAL_2:.*]]: memref<4x2x8xf32>) {
// CHECK:         linalg.conv_1d_nwc_wcf
// CHECK:           {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
// CHECK:           ins(%[[VAL_0]], %[[VAL_1]] : memref<4x6x3xf32>, memref<1x3x8xf32>)
// CHECK:           outs(%[[VAL_2]] : memref<4x2x8xf32>)
// CHECK:         return
// CHECK:       }
```


---
Full diff: https://github.com/llvm/llvm-project/pull/136819.diff


1 File Affected:

- (modified) mlir/utils/generate-test-checks.py (+12-8) 


``````````diff
diff --git a/mlir/utils/generate-test-checks.py b/mlir/utils/generate-test-checks.py
index d157af9c3cab7..649d76e4e65a2 100755
--- a/mlir/utils/generate-test-checks.py
+++ b/mlir/utils/generate-test-checks.py
@@ -1,6 +1,5 @@
 #!/usr/bin/env python3
 """A script to generate FileCheck statements for mlir unit tests.
-
 This script is a utility to add FileCheck patterns to an mlir file.
 
 NOTE: The input .mlir is expected to be the output from the parser, not a
@@ -77,13 +76,16 @@ def generate_in_parent_scope(self, n):
         self.generate_in_parent_scope_left = n
 
     # Generate a substitution name for the given ssa value name.
-    def generate_name(self, source_variable_name):
+    def generate_name(self, source_variable_name, use_ssa_name):
 
         # Compute variable name
         variable_name = self.variable_names.pop(0) if len(self.variable_names) > 0 else ''
         if variable_name == '':
-            variable_name = "VAL_" + str(self.name_counter)
-            self.name_counter += 1
+            if use_ssa_name:
+                variable_name = source_variable_name.upper()
+            else:
+                variable_name = "VAL_" + str(self.name_counter)
+                self.name_counter += 1
 
         # Scope where variable name is saved
         scope = len(self.scopes) - 1
@@ -158,7 +160,7 @@ def get_num_ssa_results(input_line):
 
 
 # Process a line of input that has been split at each SSA identifier '%'.
-def process_line(line_chunks, variable_namer, strict_name_re=False):
+def process_line(line_chunks, variable_namer, strict_name_re=False, use_ssa_name=False):
     output_line = ""
 
     # Process the rest that contained an SSA value name.
@@ -178,7 +180,7 @@ def process_line(line_chunks, variable_namer, strict_name_re=False):
             output_line += "%[[" + variable + "]]"
         else:
             # Otherwise, generate a new variable.
-            variable = variable_namer.generate_name(ssa_name)
+            variable = variable_namer.generate_name(ssa_name, use_ssa_name)
             if strict_name_re:
                 # Use stricter regexp for the variable name, if requested.
                 # Greedy matching may cause issues with the generic '.*'
@@ -415,9 +417,11 @@ def main():
                 pad_depth = label_length if label_length < 21 else 4
                 output_line += " " * pad_depth
 
-                # Process the rest of the line.
+                # Process the rest of the line. Use the original SSA name to generate the LIT
+                # variable names.
+                use_ssa_names = True
                 output_line += process_line(
-                    [argument], variable_namer, args.strict_name_re
+                    [argument], variable_namer, args.strict_name_re, use_ssa_names
                 )
 
         # Append the output line.

``````````
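
One caveat worth considering: MLIR SSA names may be purely numeric or contain `$`, `.` and `-`, while FileCheck variable names must match `[a-zA-Z_][a-zA-Z0-9_]*`, so upper-casing alone does not always yield a usable name. A minimal sketch of a guarded `generate_name`, assuming a fallback to the existing `VAL_<n>` scheme is acceptable, is shown below. It is hypothetical, not part of this patch, and omits the real `VariableNamer`'s `variable_names` list and scope tracking.

```python
import re

# FileCheck pattern variable names must match [a-zA-Z_][a-zA-Z0-9_]*.
FILECHECK_VAR_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*\Z")


class VariableNamer:
    """Hypothetical, stripped-down namer with a validity guard."""

    def __init__(self):
        self.name_counter = 0

    def generate_name(self, source_variable_name, use_ssa_name=False):
        if use_ssa_name:
            candidate = source_variable_name.upper()
            # Only reuse the SSA name if FileCheck accepts it as a variable name.
            if FILECHECK_VAR_RE.match(candidate):
                return candidate
        # Fall back to the numbered VAL_<n> scheme.
        variable_name = "VAL_" + str(self.name_counter)
        self.name_counter += 1
        return variable_name
```

With this guard, `%input` still becomes `INPUT`, while a name such as `%a.b` falls back to `VAL_0`. Giving `use_ssa_name` a default of `False` also keeps any existing two-argument callers working.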

</details>


https://github.com/llvm/llvm-project/pull/136819

