[Mlir-commits] [mlir] [mlir][tblgen] Adds support for embedded LIT tests in TableGen records (PR #158017)
Kshitij Jain
llvmlistbot at llvm.org
Sat Nov 1 17:21:27 PDT 2025
https://github.com/jkshtj updated https://github.com/llvm/llvm-project/pull/158017
>From 73613958573a098e4b3785bad0d25528404b6164 Mon Sep 17 00:00:00 2001
From: Kshitij <kjain at d-matrix.ai>
Date: Thu, 11 Sep 2025 08:32:03 +0000
Subject: [PATCH 1/2] [mlir][tblgen] Adds support for embedded LIT tests in
TableGen records
Introduces a new Testable base class that allows TableGen records (starting
with Pass records) to embed LIT test definitions directly within their
definitions. This enables co-locating tests with pass definitions for better
maintainability.
Key components:
- Testable.td: Base class for records that can have embedded tests
- LitTestGen.cpp: TableGen backend to extract and generate LIT test files
- AddMLIR.cmake: CMake function to process embedded tests with usage examples
- PassBase.td: Updated Pass class to extend Testable
Usage example in CMake:
add_embedded_lit_tests(MyPassesEmbeddedTests
${CMAKE_CURRENT_SOURCE_DIR}/include/MyPasses.td
${CMAKE_CURRENT_SOURCE_DIR}/test/Passes/)
# Make another target depend on the LIT test generation target so
# that the embedded tests are extracted before that target is built
add_dependencies(someLib MyPassesEmbeddedTests)
---
mlir/cmake/modules/AddMLIR.cmake | 100 +++++++++++++++
mlir/include/mlir/IR/Testable.td | 40 ++++++
mlir/include/mlir/Pass/PassBase.td | 4 +-
mlir/test/mlir-tblgen/gen-lit-tests.td | 65 ++++++++++
mlir/tools/mlir-tblgen/CMakeLists.txt | 1 +
mlir/tools/mlir-tblgen/LitTestGen.cpp | 170 +++++++++++++++++++++++++
6 files changed, 379 insertions(+), 1 deletion(-)
create mode 100644 mlir/include/mlir/IR/Testable.td
create mode 100644 mlir/test/mlir-tblgen/gen-lit-tests.td
create mode 100644 mlir/tools/mlir-tblgen/LitTestGen.cpp
diff --git a/mlir/cmake/modules/AddMLIR.cmake b/mlir/cmake/modules/AddMLIR.cmake
index 6589458ab7894..9b05b70231dba 100644
--- a/mlir/cmake/modules/AddMLIR.cmake
+++ b/mlir/cmake/modules/AddMLIR.cmake
@@ -762,3 +762,103 @@ function(mlir_target_link_libraries target type)
target_link_libraries(${target} ${type} ${ARGN})
endif()
endfunction()
+
+# Extracts LIT tests embedded in `Testable` records in `tblgen_file`
+# and generates a file per test in `output_dir`
+#
+# Example usage:
+# # Extract tests from MyPasses.td and generate them in test/Passes/
+# add_embedded_lit_tests(MyPassesEmbeddedTests
+# ${CMAKE_CURRENT_SOURCE_DIR}/include/MyPasses.td
+# ${CMAKE_CURRENT_SOURCE_DIR}/test/Passes/)
+#
+# # This will:
+# # 1. Process MyPasses.td with mlir-tblgen --gen-lit-tests
+# # 2. Extract individual test files to test/Passes/
+# # 3. Generate files like: test/Passes/generated_MyPass_test1.mlir
+#
+function(add_embedded_lit_tests target tblgen_file output_dir)
+ set(LLVM_TARGET_DEFINITIONS ${tblgen_file})
+
+ # Extraction script content
+ set(EXTRACT_SCRIPT_CONTENT [[
+ # Generated extraction script
+ if(NOT CONSOLIDATED_FILE)
+ message(FATAL_ERROR "CONSOLIDATED_FILE variable is required")
+ endif()
+
+ if(NOT OUTPUT_DIR)
+ message(FATAL_ERROR "OUTPUT_DIR variable is required")
+ endif()
+
+ if(NOT EXISTS ${CONSOLIDATED_FILE})
+ message(FATAL_ERROR "Consolidated file does not exist: ${CONSOLIDATED_FILE}")
+ endif()
+
+ # Read the consolidated file
+ file(READ ${CONSOLIDATED_FILE} file_content)
+
+ # Split into lines for processing
+ string(REPLACE "\n" ";" lines "${file_content}")
+
+ set(current_filename "")
+ set(current_content "")
+ set(in_test_block FALSE)
+ set(extracted_test_files)
+
+ foreach(line IN LISTS lines)
+ # Check for filename line
+ if(line MATCHES "^// File: (.+)$")
+ set(current_filename "${CMAKE_MATCH_1}")
+ endif()
+
+ # Check for BEGIN marker
+ if(line MATCHES "^// --- BEGIN .+ ---$")
+ set(in_test_block TRUE)
+ set(current_content "")
+ # Check for END marker
+ elseif(line MATCHES "^// --- END .+ ---$")
+ set(in_test_block FALSE)
+
+ # Write the extracted content to file
+ if(current_filename AND current_content)
+ file(MAKE_DIRECTORY ${OUTPUT_DIR})
+ file(WRITE ${OUTPUT_DIR}/${current_filename} "${current_content}")
+ message(STATUS "Extracted test file: ${current_filename}")
+ list(APPEND extracted_test_files ${current_filename})
+ endif()
+
+ set(current_filename "")
+ set(current_content "")
+ # Collect content within BEGIN/END block
+ elseif(in_test_block)
+ string(APPEND current_content "${line}\n")
+ endif()
+ endforeach()
+
+ list(LENGTH extracted_test_files num_extracted_files)
+ message(STATUS "Extracted ${num_extracted_files} test files to ${OUTPUT_DIR}")
+ ]])
+
+ # Write extraction script to a file in the build directory
+ file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/extract_lit_tests.cmake "${EXTRACT_SCRIPT_CONTENT}")
+
+ # Process tblgen_file and generate a file with all embedded LIT
+ # tests in tblgen_file
+ get_filename_component(tblgen_name ${tblgen_file} NAME_WE)
+ set(consolidated_output_file ${tblgen_name}_extracted_lit_tests.txt)
+ mlir_tablegen(${consolidated_output_file} --gen-lit-tests)
+
+ # Add public tablegen target to trigger builds on changes in tblgen_file
+ add_public_tablegen_target(${target})
+
+ # Call the extraction script to extract all LIT tests into individual
+ # `.mlir` test files
+ add_custom_command(TARGET ${target} POST_BUILD
+ COMMAND ${CMAKE_COMMAND}
+ -DCONSOLIDATED_FILE=${CMAKE_CURRENT_BINARY_DIR}/${consolidated_output_file}
+ -DOUTPUT_DIR=${output_dir}
+ -P ${CMAKE_CURRENT_BINARY_DIR}/extract_lit_tests.cmake
+ COMMENT "Extracting LIT tests to individual files"
+ )
+endfunction()
\ No newline at end of file
diff --git a/mlir/include/mlir/IR/Testable.td b/mlir/include/mlir/IR/Testable.td
new file mode 100644
index 0000000000000..15814ed1bd939
--- /dev/null
+++ b/mlir/include/mlir/IR/Testable.td
@@ -0,0 +1,40 @@
+//===-- Testable.td - Testable type definition file --------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the definition of the `Testable` type.
+//
+// Any type whose records can have corresponding LIT tests (e.g., Pass) can extend
+// `Testable` in order to be able to embed LIT tests within record definitions.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TESTABLE
+#define TESTABLE
+
+// Represents a LIT test record in TableGen
+class LitTest<string name, code snippet, list<string> run = [], list<string> check = []> {
+ // The name of the generated test file
+ string testFileName = name;
+
+ // The IR snippet/code to be tested
+ code irSnippet = snippet;
+
+ // The RUN commands for the test (e.g., "mlir-opt %s")
+ list<string> runLines = run;
+
+ // Expected output patterns (CHECK lines)
+ list<string> checkLines = check;
+}
+
+// Base class for elements that can have auto-generated LIT tests
+class Testable {
+ // List of LIT tests associated with this element
+ list<LitTest> tests = [];
+}
+
+#endif // TESTABLE
\ No newline at end of file
diff --git a/mlir/include/mlir/Pass/PassBase.td b/mlir/include/mlir/Pass/PassBase.td
index e37f9735e2241..50ea44419ca24 100644
--- a/mlir/include/mlir/Pass/PassBase.td
+++ b/mlir/include/mlir/Pass/PassBase.td
@@ -14,6 +14,8 @@
#ifndef MLIR_PASS_PASSBASE
#define MLIR_PASS_PASSBASE
+include "mlir/IR/Testable.td"
+
//===----------------------------------------------------------------------===//
// Options
//===----------------------------------------------------------------------===//
@@ -62,7 +64,7 @@ class Statistic<string varName, string statName, string desc> {
// Pass
//===----------------------------------------------------------------------===//
-class PassBase<string passArg, string base> {
+class PassBase<string passArg, string base> : Testable {
// The command line argument of the pass.
string argument = passArg;
diff --git a/mlir/test/mlir-tblgen/gen-lit-tests.td b/mlir/test/mlir-tblgen/gen-lit-tests.td
new file mode 100644
index 0000000000000..40a03fb2b2d60
--- /dev/null
+++ b/mlir/test/mlir-tblgen/gen-lit-tests.td
@@ -0,0 +1,65 @@
+// RUN: mlir-tblgen -gen-lit-tests -I %S/../../include -dialect=test %s | FileCheck %s
+
+include "mlir/Pass/PassBase.td"
+include "mlir/IR/Testable.td"
+
+def TestPassWithEmbeddedLitTests : Pass<"test-pass-with-embedded-lit-tests"> {
+ let summary = "pass summary";
+ let description = [{
+ Pass description
+ }];
+
+ let tests = [
+ LitTest<
+ "lit_test_file_1.mlir",
+ [{
+ func.func @test1() {
+ return 42;
+ }
+ }],
+ [
+ "// RUN: mlir-opt %s --verify-roundtrip | FileCheck %s",
+ ],
+ [
+ "// RANDOM-CHECK-LABEL: func.func @test1",
+ ]
+ >,
+ LitTest<
+ "lit_test_file_2.mlir",
+ [{
+ func.func @test2() {
+ return 42;
+ }
+ }],
+ [
+ "// RUN: mlir-opt %s --verify-roundtrip | FileCheck %s",
+ ],
+ [
+ "// RANDOM-CHECK-LABEL: func.func @test2",
+ ]
+ >,
+ ];
+}
+
+// CHECK-LABEL: // Generated 2 LIT test files
+// CHECK: // Use the following files for LIT testing:
+
+// CHECK: // File: generated_TestPassWithEmbeddedLitTests_lit_test_file_1.mlir
+// CHECK: // --- BEGIN generated_TestPassWithEmbeddedLitTests_lit_test_file_1.mlir ---
+// CHECK: // RUN: mlir-opt %s --verify-roundtrip | FileCheck %s
+// CHECK: // Generated from TableGen definition: TestPassWithEmbeddedLitTests
+// CHECK: func.func @test1() {
+// CHECK: return 42;
+// CHECK: }
+// CHECK: // RANDOM-CHECK-LABEL: func.func @test1
+// CHECK: --- END generated_TestPassWithEmbeddedLitTests_lit_test_file_1.mlir ---
+
+// CHECK: // File: generated_TestPassWithEmbeddedLitTests_lit_test_file_2.mlir
+// CHECK: // --- BEGIN generated_TestPassWithEmbeddedLitTests_lit_test_file_2.mlir ---
+// CHECK: // RUN: mlir-opt %s --verify-roundtrip | FileCheck %s
+// CHECK: // Generated from TableGen definition: TestPassWithEmbeddedLitTests
+// CHECK: func.func @test2() {
+// CHECK: return 42;
+// CHECK: }
+// CHECK: // RANDOM-CHECK-LABEL: func.func @test2
+// CHECK: // --- END generated_TestPassWithEmbeddedLitTests_lit_test_file_2.mlir ---
\ No newline at end of file
diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt
index 2a7ef7e0576c8..e721f1e26a2bd 100644
--- a/mlir/tools/mlir-tblgen/CMakeLists.txt
+++ b/mlir/tools/mlir-tblgen/CMakeLists.txt
@@ -16,6 +16,7 @@ add_tablegen(mlir-tblgen MLIR
EnumsGen.cpp
EnumPythonBindingGen.cpp
FormatGen.cpp
+ LitTestGen.cpp
LLVMIRConversionGen.cpp
LLVMIRIntrinsicGen.cpp
mlir-tblgen.cpp
diff --git a/mlir/tools/mlir-tblgen/LitTestGen.cpp b/mlir/tools/mlir-tblgen/LitTestGen.cpp
new file mode 100644
index 0000000000000..49a092fa9879f
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/LitTestGen.cpp
@@ -0,0 +1,170 @@
+//===- LitTestGen.cpp - LIT test generator ----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// LitTestGen extracts `LitTest` records from `Testable` TableGen records and
+// generates corresponding LIT test files.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/Operator.h"
+#include "mlir/TableGen/Pass.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/Path.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+
+#include <set>
+
+using namespace mlir;
+using namespace mlir::tblgen;
+using llvm::formatv;
+using llvm::RecordKeeper;
+
+static llvm::cl::OptionCategory litTestGenCategory("Options for -gen-lit-tests");
+static llvm::cl::opt<std::string>
+ outputDir("output-dir",
+ llvm::cl::desc("Output directory for generated test files"),
+ llvm::cl::cat(litTestGenCategory),
+ llvm::cl::value_desc("directory"));
+
+
+/// Cpp type corresponding to the `LitTest` record type in TableGen
+struct LitTest {
+ std::string sourceDefName;
+ std::string testFileName;
+ std::string irSnippet;
+ llvm::SmallVector<std::string> runLines;
+ llvm::SmallVector<std::string> checkLines;
+};
+
+static llvm::SmallVector<LitTest> extractTestsFromRecord(const llvm::Record *record,
+ llvm::StringRef dialectName = "") {
+ llvm::SmallVector<LitTest> tests;
+
+ // Check if the record has a tests field
+ const llvm::RecordVal *testsVal = record->getValue("tests");
+ if (!testsVal)
+ return tests;
+
+ const llvm::ListInit *testsList =
+ llvm::dyn_cast_or_null<llvm::ListInit>(testsVal->getValue());
+ if (!testsList)
+ return tests;
+
+ for (const llvm::Init *init : testsList->getElements()) {
+ const llvm::DefInit *defInit = llvm::dyn_cast<llvm::DefInit>(init);
+ if (!defInit)
+ continue;
+
+ const llvm::Record *testRec = defInit->getDef();
+
+ // Extract fields from LitTest record
+ std::string name = testRec->getValueAsString("testFileName").str();
+ std::string irSnippet = testRec->getValueAsString("irSnippet").str();
+
+ llvm::SmallVector<std::string> runLines;
+ llvm::for_each(*testRec->getValueAsListInit("runLines"), [&](const llvm::Init *init) {
+ runLines.emplace_back(llvm::cast<llvm::StringInit>(init)->getValue());
+ });
+
+ llvm::SmallVector<std::string> checkLines;
+ llvm::for_each(*testRec->getValueAsListInit("checkLines"), [&](const llvm::Init *init) {
+ checkLines.emplace_back(llvm::cast<llvm::StringInit>(init)->getValue());
+ });
+
+ tests.push_back(LitTest {
+ record->getName().str(),
+ name,
+ irSnippet,
+ runLines,
+ checkLines,
+ });
+ }
+
+ return tests;
+}
+
+/// Extract tests from passes
+static llvm::SmallVector<LitTest> extractPassTests(const RecordKeeper &records) {
+ llvm::SmallVector<LitTest> tests;
+
+ // Check if PassBase class exists before trying to get derived definitions
+ if (records.getClass("PassBase")) {
+ for (const llvm::Record *def : records.getAllDerivedDefinitions("PassBase")) {
+ if (def->isAnonymous())
+ continue;
+
+ auto passTests = extractTestsFromRecord(def, "passes");
+ tests.insert(tests.end(), passTests.begin(), passTests.end());
+ }
+ }
+
+ return tests;
+}
+
+/// Generate a LIT test file for an IR test
+static void generateTestFile(const LitTest &test, llvm::raw_ostream &os) {
+ // Add RUN lines
+ for (const auto& runLine : test.runLines) {
+ os << "\n" << runLine << "\n";
+ }
+
+ os << "// Generated from TableGen definition: " << test.sourceDefName << "\n\n";
+
+ // Add the test body
+ os << test.irSnippet << "\n";
+
+ // Add CHECK lines
+ for (const auto& checkLine : test.checkLines) {
+ os << "\n" << checkLine << "\n";
+ }
+}
+
+/// Main function to generate all IR test test files
+static void generateLitTests(const RecordKeeper &records, raw_ostream &os) {
+ llvm::SmallVector<LitTest> allTests;
+
+ // Extract tests from different definition types (only passes for now)
+ auto passTests = extractPassTests(records);
+
+ allTests.insert(allTests.end(), passTests.begin(), passTests.end());
+
+ if (allTests.empty()) {
+ os << "// No LitTest record found in any TableGen definition\n";
+ return;
+ }
+
+ // Generate summary
+ os << "// Generated " << allTests.size() << " LIT test files\n";
+ os << "// Use the following files for LIT testing:\n\n";
+
+ // Generate file list and content for each test
+ for (const auto& test : allTests) {
+ std::string testFileName = formatv("generated_{0}_{1}", test.sourceDefName, test.testFileName);
+ os << "// File: " << testFileName << "\n";
+
+ os << "// --- BEGIN " << testFileName << " ---\n";
+ generateTestFile(test, os);
+ os << "// --- END " << testFileName << " ---\n\n";
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// Generator Registration
+//===----------------------------------------------------------------------===//
+
+static mlir::GenRegistration
+ genLitTests("gen-lit-tests", "Generate LIT test files for `Testable` TableGen records",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ generateLitTests(records, os);
+ return false;
+ });
\ No newline at end of file
>From c312060a4d46e64847505397ff97e395eacdbbd4 Mon Sep 17 00:00:00 2001
From: Kshitij <kjain at d-matrix.ai>
Date: Sun, 2 Nov 2025 00:17:20 +0000
Subject: [PATCH 2/2] tmp
---
mlir/examples/toy/Ch2/CMakeLists.txt | 8 +-
mlir/include/mlir/IR/Testable.td | 6 -
mlir/include/mlir/Pass/PassBase.td | 2 +-
mlir/test/mlir-tblgen/gen-lit-tests.td | 100 +++++++--------
mlir/tools/mlir-tblgen/LitTestGen.cpp | 168 ++++++++++++++++++++-----
5 files changed, 189 insertions(+), 95 deletions(-)
diff --git a/mlir/examples/toy/Ch2/CMakeLists.txt b/mlir/examples/toy/Ch2/CMakeLists.txt
index 3fbff2fa2a679..4df06410c9261 100644
--- a/mlir/examples/toy/Ch2/CMakeLists.txt
+++ b/mlir/examples/toy/Ch2/CMakeLists.txt
@@ -1,6 +1,11 @@
# For a better template to copy, see examples/standalone
add_subdirectory(include)
+add_embedded_lit_tests(EmbeddedLitTestsGen
+ "include/toy/TestPasses.td"
+ "/disk1/kjain/workspace/llvm-project/mlir/examples/toy/Ch2/jkshtj"
+)
+
set(LLVM_LINK_COMPONENTS
Support
)
@@ -13,8 +18,9 @@ add_toy_chapter(toyc-ch2
DEPENDS
ToyCh2OpsIncGen
-
+ EmbeddedLitTestsGen
)
+
include_directories(include/)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/)
target_link_libraries(toyc-ch2
diff --git a/mlir/include/mlir/IR/Testable.td b/mlir/include/mlir/IR/Testable.td
index 15814ed1bd939..ca96e4f0b8b74 100644
--- a/mlir/include/mlir/IR/Testable.td
+++ b/mlir/include/mlir/IR/Testable.td
@@ -31,10 +31,4 @@ class LitTest<string name, code snippet, list<string> run = [], list<string> che
list<string> checkLines = check;
}
-// Base class for elements that can have auto-generated LIT tests
-class Testable {
- // List of LIT tests associated with this element
- list<LitTest> tests = [];
-}
-
#endif // TESTABLE
\ No newline at end of file
diff --git a/mlir/include/mlir/Pass/PassBase.td b/mlir/include/mlir/Pass/PassBase.td
index 50ea44419ca24..d47780b735eea 100644
--- a/mlir/include/mlir/Pass/PassBase.td
+++ b/mlir/include/mlir/Pass/PassBase.td
@@ -64,7 +64,7 @@ class Statistic<string varName, string statName, string desc> {
// Pass
//===----------------------------------------------------------------------===//
-class PassBase<string passArg, string base> : Testable {
+class PassBase<string passArg, string base> {
// The command line argument of the pass.
string argument = passArg;
diff --git a/mlir/test/mlir-tblgen/gen-lit-tests.td b/mlir/test/mlir-tblgen/gen-lit-tests.td
index 40a03fb2b2d60..834695d4617a9 100644
--- a/mlir/test/mlir-tblgen/gen-lit-tests.td
+++ b/mlir/test/mlir-tblgen/gen-lit-tests.td
@@ -2,64 +2,58 @@
include "mlir/Pass/PassBase.td"
include "mlir/IR/Testable.td"
+include "mlir/IR/OpBase.td"
-def TestPassWithEmbeddedLitTests : Pass<"test-pass-with-embedded-lit-tests"> {
- let summary = "pass summary";
+def Test_Dialect : Dialect {
+ let name = "test";
+ let cppNamespace = "test";
+}
+
+def TestOp : Op<Test_Dialect, "test_op"> {
+ let summary = "test op with mlir_example code blocks";
let description = [{
- Pass description
+ This operation demonstrates the mlir_example feature for ops.
+
+ Basic usage:
+ ```mlir_example
+ func.func @foo(%arg0: i32) -> i32 {
+ %0 = test.test_op %arg0 : i32
+ return %0 : i32
+ }
+ ```
+
+ And some more examples -
+
+ ```mlir_example
+ func.func @foo1(%arg1: i32) -> i32 {
+ %0 = test.test_op %arg1 : i32
+ return %0 : i32
+ }
+ ```
}];
-
- let tests = [
- LitTest<
- "lit_test_file_1.mlir",
- [{
- func.func @test1() {
- return 42;
- }
- }],
- [
- "// RUN: mlir-opt %s --verify-roundtrip | FileCheck %s",
- ],
- [
- "// RANDOM-CHECK-LABEL: func.func @test1",
- ]
- >,
- LitTest<
- "lit_test_file_2.mlir",
- [{
- func.func @test2() {
- return 42;
- }
- }],
- [
- "// RUN: mlir-opt %s --verify-roundtrip | FileCheck %s",
- ],
- [
- "// RANDOM-CHECK-LABEL: func.func @test2",
- ]
- >,
- ];
-}
-// CHECK-LABEL: // Generated 2 LIT test files
-// CHECK: // Use the following files for LIT testing:
+ let arguments = (ins I32:$input);
+ let results = (outs I32:$output);
+}
-// CHECK: // File: generated_TestPassWithEmbeddedLitTests_lit_test_file_1.mlir
-// CHECK: // --- BEGIN generated_TestPassWithEmbeddedLitTests_lit_test_file_1.mlir ---
-// CHECK: // RUN: mlir-opt %s --verify-roundtrip | FileCheck %s
-// CHECK: // Generated from TableGen definition: TestPassWithEmbeddedLitTests
-// CHECK: func.func @test1() {
-// CHECK: return 42;
+// CHECK: // File: generated_TestOp_example_0.mlir
+// CHECK: // --- BEGIN generated_TestOp_example_0.mlir ---
+// CHECK: // RUN: mlir-opt %s --verify-roundtrip
+// CHECK: // Generated from TableGen definition: TestOp
+// CHECK: func.func @foo(%arg0: i32) -> i32 {
+// CHECK: %0 = test.test_op %arg0 : i32
+// CHECK: return %0 : i32
// CHECK: }
-// CHECK: // RANDOM-CHECK-LABEL: func.func @test1
-// CHECK: --- END generated_TestPassWithEmbeddedLitTests_lit_test_file_1.mlir ---
+// CHECK: // --- END generated_TestOp_example_0.mlir ---
-// CHECK: // File: generated_TestPassWithEmbeddedLitTests_lit_test_file_2.mlir
-// CHECK: // --- BEGIN generated_TestPassWithEmbeddedLitTests_lit_test_file_2.mlir ---
-// CHECK: // RUN: mlir-opt %s --verify-roundtrip | FileCheck %s
-// CHECK: // Generated from TableGen definition: TestPassWithEmbeddedLitTests
-// CHECK: func.func @test2() {
-// CHECK: return 42;
+// CHECK: // File: generated_TestOp_example_1.mlir
+// CHECK: // --- BEGIN generated_TestOp_example_1.mlir ---
+// CHECK: // RUN: mlir-opt %s --verify-roundtrip
+// CHECK: // Generated from TableGen definition: TestOp
+// CHECK: func.func @bar(%arg0: i32, %arg1: i32) -> i32 {
+// CHECK: %0 = test.test_op %arg0 : i32
+// CHECK: %1 = test.test_op %arg1 : i32
+// CHECK: %2 = arith.addi %0, %1 : i32
+// CHECK: return %2 : i32
// CHECK: }
-// CHECK: // RANDOM-CHECK-LABEL: func.func @test2
-// CHECK: // --- END generated_TestPassWithEmbeddedLitTests_lit_test_file_2.mlir ---
\ No newline at end of file
+// CHECK: // --- END generated_TestOp_example_1.mlir ---
\ No newline at end of file
diff --git a/mlir/tools/mlir-tblgen/LitTestGen.cpp b/mlir/tools/mlir-tblgen/LitTestGen.cpp
index 49a092fa9879f..a03ccdfb79d38 100644
--- a/mlir/tools/mlir-tblgen/LitTestGen.cpp
+++ b/mlir/tools/mlir-tblgen/LitTestGen.cpp
@@ -19,6 +19,7 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/Path.h"
+#include "llvm/Support/Regex.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
@@ -41,36 +42,136 @@ static llvm::cl::opt<std::string>
struct LitTest {
std::string sourceDefName;
std::string testFileName;
- std::string irSnippet;
+ std::string irSnippet;
llvm::SmallVector<std::string> runLines;
llvm::SmallVector<std::string> checkLines;
};
+/// Extract code snippets with mlir_example tag from a description field.
+/// Returns a vector of code snippets found within ```mlir_example ... ``` blocks.
+static llvm::SmallVector<std::string> extractMlirExamples(llvm::StringRef description) {
+ llvm::SmallVector<std::string> examples;
+
+ // Pattern to match ```mlir_example ... ``` code blocks
+ // ```mlir_example matches the opening fence together with its tag
+ // (.+) captures the code content; the quantifier is greedy, so the
+ // capture runs up to the last closing fence that can still match
+ // ``` matches the closing fence
+ llvm::Regex codeBlockRegex("```mlir_example(.+)```");
+
+ llvm::StringRef remaining = description;
+ llvm::SmallVector<llvm::StringRef> matches;
+
+ while (codeBlockRegex.match(remaining, &matches)) {
+ if (matches.size() >= 2) {
+ // matches[1] contains the captured group (the code content)
+ std::string code = matches[1].str();
+
+ llvm::errs() << "DEBUG: Extracted raw code:\n[" << code << "]\n";
+
+ // Remove leading whitespace and leading comment markers (# prefix)
+ llvm::SmallVector<llvm::StringRef> lines;
+ llvm::StringRef codeRef(code);
+ codeRef.split(lines, '\n', -1, false);
+
+ std::string processedCode;
+ for (llvm::StringRef line : lines) {
+ line = line.ltrim();
+ // Remove leading # comment markers if present
+ if (line.starts_with("#")) {
+ line = line.drop_front(1).ltrim();
+ }
+ if (!line.empty() || !processedCode.empty()) {
+ processedCode += line.str() + "\n";
+ }
+ }
+
+ // // Remove trailing empty lines
+ // while (!processedCode.empty() && processedCode.back() == '\n') {
+ // size_t lastNewline = processedCode.find_last_not_of('\n');
+ // if (lastNewline == std::string::npos) {
+ // processedCode.clear();
+ // break;
+ // }
+ // processedCode = processedCode.substr(0, lastNewline + 1) + "\n";
+ // }
+
+ if (!processedCode.empty()) {
+ examples.push_back(processedCode);
+ }
+ }
+
+ // Move past this match to find the next one
+ size_t matchEnd = remaining.find("```", remaining.find("```mlir_example") + 15);
+ if (matchEnd == llvm::StringRef::npos)
+ break;
+ remaining = remaining.substr(matchEnd + 3);
+ }
+
+ return examples;
+}
+
static llvm::SmallVector<LitTest> extractTestsFromRecord(const llvm::Record *record,
llvm::StringRef dialectName = "") {
llvm::SmallVector<LitTest> tests;
-
- // Check if the record has a tests field
+
+ // Try to extract mlir_example code blocks from the description field
+ const llvm::RecordVal *descVal = record->getValue("description");
+ if (descVal) {
+ llvm::StringRef description = record->getValueAsString("description");
+ llvm::errs() << "DEBUG: Record: " << record->getName() << "\n";
+ llvm::errs() << "DEBUG: Description length: " << description.size() << "\n";
+ llvm::errs() << "DEBUG: Description content:\n" << description << "\n";
+ llvm::errs() << "DEBUG: ---\n";
+ if (!description.empty()) {
+ llvm::SmallVector<std::string> examples = extractMlirExamples(description);
+ llvm::errs() << "DEBUG: Found " << examples.size() << " examples\n";
+
+ // Create a LitTest for each extracted example
+ for (size_t i = 0; i < examples.size(); ++i) {
+ std::string testFileName;
+ if (examples.size() == 1) {
+ testFileName = "example.mlir";
+ } else {
+ testFileName = formatv("example_{0}.mlir", i);
+ }
+
+ // Generate default RUN line with --verify-roundtrip
+ llvm::SmallVector<std::string> runLines;
+ runLines.push_back("// RUN: mlir-opt %s --verify-roundtrip");
+
+ tests.push_back(LitTest {
+ record->getName().str(),
+ testFileName,
+ examples[i],
+ runLines,
+ {} // No CHECK lines by default
+ });
+ }
+ }
+ }
+
+ // Also check the old tests field for backward compatibility
const llvm::RecordVal *testsVal = record->getValue("tests");
if (!testsVal)
return tests;
-
- const llvm::ListInit *testsList =
+
+ const llvm::ListInit *testsList =
llvm::dyn_cast_or_null<llvm::ListInit>(testsVal->getValue());
if (!testsList)
return tests;
-
+
for (const llvm::Init *init : testsList->getElements()) {
const llvm::DefInit *defInit = llvm::dyn_cast<llvm::DefInit>(init);
if (!defInit)
continue;
-
+
const llvm::Record *testRec = defInit->getDef();
-
+
// Extract fields from LitTest record
std::string name = testRec->getValueAsString("testFileName").str();
std::string irSnippet = testRec->getValueAsString("irSnippet").str();
-
+
llvm::SmallVector<std::string> runLines;
llvm::for_each(*testRec->getValueAsListInit("runLines"), [&](const llvm::Init *init) {
runLines.emplace_back(llvm::cast<llvm::StringInit>(init)->getValue());
@@ -83,31 +184,31 @@ static llvm::SmallVector<LitTest> extractTestsFromRecord(const llvm::Record *rec
tests.push_back(LitTest {
record->getName().str(),
- name,
- irSnippet,
- runLines,
- checkLines,
+ name,
+ irSnippet,
+ runLines,
+ checkLines,
});
}
-
+
return tests;
}
-/// Extract tests from passes
-static llvm::SmallVector<LitTest> extractPassTests(const RecordKeeper &records) {
+/// Extract tests from ops
+static llvm::SmallVector<LitTest> extractOpTests(const RecordKeeper &records) {
llvm::SmallVector<LitTest> tests;
-
- // Check if PassBase class exists before trying to get derived definitions
- if (records.getClass("PassBase")) {
- for (const llvm::Record *def : records.getAllDerivedDefinitions("PassBase")) {
+
+ // Check if Op class exists before trying to get derived definitions
+ if (records.getClass("Op")) {
+ for (const llvm::Record *def : records.getAllDerivedDefinitions("Op")) {
if (def->isAnonymous())
continue;
-
- auto passTests = extractTestsFromRecord(def, "passes");
- tests.insert(tests.end(), passTests.begin(), passTests.end());
+
+ auto opTests = extractTestsFromRecord(def, "ops");
+ tests.insert(tests.end(), opTests.begin(), opTests.end());
}
}
-
+
return tests;
}
@@ -132,26 +233,25 @@ static void generateTestFile(const LitTest &test, llvm::raw_ostream &os) {
/// Main function to generate all IR test test files
static void generateLitTests(const RecordKeeper &records, raw_ostream &os) {
llvm::SmallVector<LitTest> allTests;
-
- // Extract tests from different definition types (only passes for now)
- auto passTests = extractPassTests(records);
-
- allTests.insert(allTests.end(), passTests.begin(), passTests.end());
-
+
+ // Extract tests from different definition types
+ auto opTests = extractOpTests(records);
+ allTests.insert(allTests.end(), opTests.begin(), opTests.end());
+
if (allTests.empty()) {
- os << "// No LitTest record found in any TableGen definition\n";
+ os << "// No mlir_example code blocks found in any TableGen definition\n";
return;
}
-
+
// Generate summary
os << "// Generated " << allTests.size() << " LIT test files\n";
os << "// Use the following files for LIT testing:\n\n";
-
+
// Generate file list and content for each test
for (const auto& test : allTests) {
std::string testFileName = formatv("generated_{0}_{1}", test.sourceDefName, test.testFileName);
os << "// File: " << testFileName << "\n";
-
+
os << "// --- BEGIN " << testFileName << " ---\n";
generateTestFile(test, os);
os << "// --- END " << testFileName << " ---\n\n";
More information about the Mlir-commits
mailing list