[Mlir-commits] [mlir] Transform Interpreter: Prefer entry points in current module (PR #151323)
Erick Ochoa Lopez
llvmlistbot at llvm.org
Wed Jul 30 05:50:38 PDT 2025
https://github.com/amd-eochoalo created https://github.com/llvm/llvm-project/pull/151323
The transform interpreter previously looked for the entry point using a recursive pre-order walk. As a result, a named_sequence operation at any nesting depth was chosen as the entry point as long as it appeared before another one with the same name.
This change first searches the operations nested directly in the regions of the current module and falls back to the pre-order walk only when no match is found, so that code like the one reported in https://github.com/llvm/llvm-project/issues/119578 works as expected.
Closes #119578
Some comments: alternatively, this issue could be solved in a slightly more elegant manner. We could define a new walker iterator that visits the operations in breadth-first order.
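To make the difference between the two traversal orders concrete, here is a minimal sketch on a toy tree. It does not use the MLIR API: Node, findPreOrder, findBreadthFirst and makeExample are hypothetical names introduced only for illustration, and the tree merely mirrors the nesting used in the new test file (a top-level sequence placed after two nested modules).

```cpp
#include <deque>
#include <string>
#include <vector>

// Hypothetical stand-in for an operation (not an MLIR type).
struct Node {
  std::string owner;          // name of the enclosing module, used to tell matches apart
  std::string symbol;         // symbol name; empty for nodes that are not sequences
  std::vector<Node> children; // nested operations, in textual order
};

// Pre-order walk: visit a node, then recurse into its children in order.
// A deeply nested match wins if it appears textually before a shallower one.
const Node *findPreOrder(const Node &node, const std::string &symbol) {
  if (node.symbol == symbol)
    return &node;
  for (const Node &child : node.children)
    if (const Node *found = findPreOrder(child, symbol))
      return found;
  return nullptr;
}

// Breadth-first search: visit all nodes at depth d before any node at depth d + 1.
// The shallowest match wins regardless of textual order.
const Node *findBreadthFirst(const Node &root, const std::string &symbol) {
  std::deque<const Node *> queue{&root};
  while (!queue.empty()) {
    const Node *current = queue.front();
    queue.pop_front();
    if (current->symbol == symbol)
      return current;
    for (const Node &child : current->children)
      queue.push_back(&child);
  }
  return nullptr;
}

// Toy tree shaped like the test case: two nested modules, each with its own
// @__transform_main, followed by a top-level @__transform_main.
Node makeExample() {
  Node fooModule{"td_module_4", "", {Node{"foo_module", "__transform_main", {}}}};
  Node barModule{"td_module_4", "", {Node{"bar_module", "__transform_main", {}}}};
  Node topSequence{"td_module_4", "__transform_main", {}};
  return Node{"", "", {fooModule, barModule, topSequence}};
}
```

On this tree, findPreOrder returns the sequence inside foo_module, while findBreadthFirst returns the top-level one. The patch below reaches the same preference differently: it checks the directly nested operations first and keeps the pre-order walk as a fallback.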
From 88397999b92a54e36d9ae5de621cd1203bf7bba1 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 30 Jul 2025 08:30:47 -0400
Subject: [PATCH] Prefer entry_points in current module.
The transform interpreter previously looked for the entry point
using a recursive walk in pre-order. This makes it so that any
named_sequence operation with an arbitrary level of nested-ness
will be used as the entry point for the transform interpreter as
long as it is placed before another one.
This change makes it so that code like the one reported in
https://github.com/llvm/llvm-project/issues/119578 works as expected.
---
.../Transforms/TransformInterpreterUtils.cpp | 24 +++++++++++++++----
.../Transform/interpreter-entry-point-2.mlir | 24 +++++++++++++++++++
2 files changed, 43 insertions(+), 5 deletions(-)
create mode 100644 mlir/test/Dialect/Transform/interpreter-entry-point-2.mlir
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
index 35ace1b2e0c3a..dc597d6ef2a1b 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
@@ -129,15 +129,29 @@ transform::detail::findTransformEntryPoint(Operation *root, ModuleOp module,
l.push_back(module);
for (Operation *op : l) {
transform::TransformOpInterface transform = nullptr;
- op->walk<WalkOrder::PreOrder>(
- [&](transform::NamedSequenceOp namedSequenceOp) {
+ for (Region &region : op->getRegions()) {
+ for (Block &block : region.getBlocks()) {
+ auto namedSequenceOps = block.getOps<transform::NamedSequenceOp>();
+ for (transform::NamedSequenceOp namedSequenceOp : namedSequenceOps) {
if (namedSequenceOp.getSymName() == entryPoint) {
transform = cast<transform::TransformOpInterface>(
namedSequenceOp.getOperation());
- return WalkResult::interrupt();
+ break;
}
- return WalkResult::advance();
- });
+ }
+ }
+ }
+ if (!transform) {
+ op->walk<WalkOrder::PreOrder>(
+ [&](transform::NamedSequenceOp namedSequenceOp) {
+ if (namedSequenceOp.getSymName() == entryPoint) {
+ transform = cast<transform::TransformOpInterface>(
+ namedSequenceOp.getOperation());
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
+ });
+ }
if (transform)
return transform;
}
diff --git a/mlir/test/Dialect/Transform/interpreter-entry-point-2.mlir b/mlir/test/Dialect/Transform/interpreter-entry-point-2.mlir
new file mode 100644
index 0000000000000..e3e901a7eaf02
--- /dev/null
+++ b/mlir/test/Dialect/Transform/interpreter-entry-point-2.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt %s -transform-interpreter \
+// RUN: -split-input-file -verify-diagnostics | FileCheck %s
+
+module @td_module_4 attributes {transform.with_named_sequence} {
+ module @foo_module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) -> () {
+ // CHECK: IR printer: foo_module top-level
+ transform.print {name="foo_module"}
+ transform.yield
+ }
+ }
+ module @bar_module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) -> () {
+ // CHECK: IR printer: bar_module top-level
+ transform.print {name="bar_module"}
+ transform.yield
+ }
+ }
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) -> () {
+ transform.include @foo_module::@__transform_main failures(suppress) (%arg0) : (!transform.any_op) -> ()
+ transform.include @bar_module::@__transform_main failures(suppress) (%arg0) : (!transform.any_op) -> ()
+ transform.yield
+ }
+}