[Mlir-commits] [mlir] [mlir][utils] Update generate-test-checks.py (use SSA names) (PR #136819)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Apr 24 08:23:10 PDT 2025


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/136819

>From ba380d5da91dda5de866acba489ae03740614886 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 23 Apr 2025 06:55:49 +0000
Subject: [PATCH 1/2] [mlir][utils] Update generate-test-checks.py (use SSA
 names)

This patch updates generate-test-checks.py to preserve the original SSA
names (converted to upper case) when generating LIT variable names for
function arguments (i.e. for `CHECK-SAME` lines). This improves
readability and helps maintain consistency between the input MLIR and
the expected FileCheck/LIT output.

For example, given the following function:

```mlir
func.func @conv1d_nwc_4x2x8_memref(
    %input: memref<4x6x3xf32>,
    %filter: memref<1x3x8xf32>,
    %output: memref<4x2x8xf32>) {

  linalg.conv_1d_nwc_wcf
    {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
    ins(%input, %filter : memref<4x6x3xf32>, memref<1x3x8xf32>)
    outs(%output : memref<4x2x8xf32>)

  return
}
```

The generated output becomes:

```mlir
// CHECK-LABEL: func.func @conv1d_nwc_4x2x8_memref(
// CHECK-SAME:      %[[INPUT:.*]]: memref<4x6x3xf32>,
// CHECK-SAME:      %[[FILTER:.*]]: memref<1x3x8xf32>,
// CHECK-SAME:      %[[OUTPUT:.*]]: memref<4x2x8xf32>) {
// CHECK:         linalg.conv_1d_nwc_wcf
// CHECK:           {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
// CHECK:           ins(%[[INPUT]], %[[FILTER]] : memref<4x6x3xf32>, memref<1x3x8xf32>)
// CHECK:           outs(%[[OUTPUT]] : memref<4x2x8xf32>)
// CHECK:         return
// CHECK:       }
```

By contrast, the current version of the script would generate:

```mlir
// CHECK-LABEL: func.func @conv1d_nwc_4x2x8_memref(
// CHECK-SAME:      %[[VAL_0:.*]]: memref<4x6x3xf32>,
// CHECK-SAME:      %[[VAL_1:.*]]: memref<1x3x8xf32>,
// CHECK-SAME:      %[[VAL_2:.*]]: memref<4x2x8xf32>) {
// CHECK:         linalg.conv_1d_nwc_wcf
// CHECK:           {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
// CHECK:           ins(%[[VAL_0]], %[[VAL_1]] : memref<4x6x3xf32>, memref<1x3x8xf32>)
// CHECK:           outs(%[[VAL_2]] : memref<4x2x8xf32>)
// CHECK:         return
// CHECK:       }
```
---
 mlir/utils/generate-test-checks.py | 20 ++++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/mlir/utils/generate-test-checks.py b/mlir/utils/generate-test-checks.py
index d157af9c3cab7..72167c27f11de 100755
--- a/mlir/utils/generate-test-checks.py
+++ b/mlir/utils/generate-test-checks.py
@@ -1,6 +1,5 @@
 #!/usr/bin/env python3
 """A script to generate FileCheck statements for mlir unit tests.
-
 This script is a utility to add FileCheck patterns to an mlir file.
 
 NOTE: The input .mlir is expected to be the output from the parser, not a
@@ -77,13 +76,16 @@ def generate_in_parent_scope(self, n):
         self.generate_in_parent_scope_left = n
 
     # Generate a substitution name for the given ssa value name.
-    def generate_name(self, source_variable_name):
+    def generate_name(self, source_variable_name, use_ssa_name):
 
         # Compute variable name
         variable_name = self.variable_names.pop(0) if len(self.variable_names) > 0 else ''
         if variable_name == '':
-            variable_name = "VAL_" + str(self.name_counter)
-            self.name_counter += 1
+            if use_ssa_name:
+                variable_name = source_variable_name.upper()
+            else:
+                variable_name = "VAL_" + str(self.name_counter)
+                self.name_counter += 1
 
         # Scope where variable name is saved
         scope = len(self.scopes) - 1
@@ -158,7 +160,7 @@ def get_num_ssa_results(input_line):
 
 
 # Process a line of input that has been split at each SSA identifier '%'.
-def process_line(line_chunks, variable_namer, strict_name_re=False):
+def process_line(line_chunks, variable_namer, use_ssa_name=False, strict_name_re=False):
     output_line = ""
 
     # Process the rest that contained an SSA value name.
@@ -178,7 +180,7 @@ def process_line(line_chunks, variable_namer, strict_name_re=False):
             output_line += "%[[" + variable + "]]"
         else:
             # Otherwise, generate a new variable.
-            variable = variable_namer.generate_name(ssa_name)
+            variable = variable_namer.generate_name(ssa_name, use_ssa_name)
             if strict_name_re:
                 # Use stricter regexp for the variable name, if requested.
                 # Greedy matching may cause issues with the generic '.*'
@@ -415,9 +417,11 @@ def main():
                 pad_depth = label_length if label_length < 21 else 4
                 output_line += " " * pad_depth
 
-                # Process the rest of the line.
+                # Process the rest of the line. Use the original SSA name to generate the LIT
+                # variable names.
+                use_ssa_names = True
                 output_line += process_line(
-                    [argument], variable_namer, args.strict_name_re
+                    [argument], variable_namer, use_ssa_names, args.strict_name_re
                 )
 
         # Append the output line.

>From b34141a321eef47fb86e3533a81c307a448f13dd Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 24 Apr 2025 15:16:08 +0000
Subject: [PATCH 2/2] Make sure that for MLIR names like `%0`, valid LIT var
 names are generated.

---
 mlir/utils/generate-test-checks.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/mlir/utils/generate-test-checks.py b/mlir/utils/generate-test-checks.py
index 72167c27f11de..11fb4e40072e7 100755
--- a/mlir/utils/generate-test-checks.py
+++ b/mlir/utils/generate-test-checks.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 """A script to generate FileCheck statements for mlir unit tests.
+
 This script is a utility to add FileCheck patterns to an mlir file.
 
 NOTE: The input .mlir is expected to be the output from the parser, not a
@@ -81,7 +82,11 @@ def generate_name(self, source_variable_name, use_ssa_name):
         # Compute variable name
         variable_name = self.variable_names.pop(0) if len(self.variable_names) > 0 else ''
         if variable_name == '':
-            if use_ssa_name:
+            # If `use_ssa_name` is set, use the MLIR SSA value name to generate
+            # a FileCheck substitution variable name. As FileCheck requires
+            # these names to start with a letter, skip MLIR values whose names
+            # start with a digit (e.g. `%0`).
+            if use_ssa_name and source_variable_name[0].isalpha():
                 variable_name = source_variable_name.upper()
             else:
                 variable_name = "VAL_" + str(self.name_counter)



More information about the Mlir-commits mailing list