[Mlir-commits] [mlir] [MLIR][Transform] Prefer entry points in current module (PR #151323)

Erick Ochoa Lopez llvmlistbot at llvm.org
Thu Jul 31 07:56:26 PDT 2025


https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/151323

>From 88397999b92a54e36d9ae5de621cd1203bf7bba1 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 30 Jul 2025 08:30:47 -0400
Subject: [PATCH 1/6] Prefer entry_points in current module.

The transform interpreter previously looked for the entry point
using a recursive pre-order walk. As a result, a named_sequence
operation with the entry point's name nested at an arbitrary depth
was selected as the entry point as long as it appeared before any
other named_sequence with the same name, even when the current
module itself defines one.

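This change first searches the named_sequence operations that live
directly in the blocks of the operation's regions, and only falls back
to the recursive walk when none of them matches. This makes code like
the example reported in
https://github.com/llvm/llvm-project/issues/119578 work as expected.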
---
 .../Transforms/TransformInterpreterUtils.cpp  | 24 +++++++++++++++----
 .../Transform/interpreter-entry-point-2.mlir  | 24 +++++++++++++++++++
 2 files changed, 43 insertions(+), 5 deletions(-)
 create mode 100644 mlir/test/Dialect/Transform/interpreter-entry-point-2.mlir

diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
index 35ace1b2e0c3a..dc597d6ef2a1b 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
@@ -129,15 +129,29 @@ transform::detail::findTransformEntryPoint(Operation *root, ModuleOp module,
     l.push_back(module);
   for (Operation *op : l) {
     transform::TransformOpInterface transform = nullptr;
-    op->walk<WalkOrder::PreOrder>(
-        [&](transform::NamedSequenceOp namedSequenceOp) {
+    for (Region &region : op->getRegions()) {
+      for (Block &block : region.getBlocks()) {
+        auto namedSequenceOps = block.getOps<transform::NamedSequenceOp>();
+        for (transform::NamedSequenceOp namedSequenceOp : namedSequenceOps) {
           if (namedSequenceOp.getSymName() == entryPoint) {
             transform = cast<transform::TransformOpInterface>(
                 namedSequenceOp.getOperation());
-            return WalkResult::interrupt();
+            break;
           }
-          return WalkResult::advance();
-        });
+        }
+      }
+    }
+    if (!transform) {
+      op->walk<WalkOrder::PreOrder>(
+          [&](transform::NamedSequenceOp namedSequenceOp) {
+            if (namedSequenceOp.getSymName() == entryPoint) {
+              transform = cast<transform::TransformOpInterface>(
+                  namedSequenceOp.getOperation());
+              return WalkResult::interrupt();
+            }
+            return WalkResult::advance();
+          });
+    }
     if (transform)
       return transform;
   }
diff --git a/mlir/test/Dialect/Transform/interpreter-entry-point-2.mlir b/mlir/test/Dialect/Transform/interpreter-entry-point-2.mlir
new file mode 100644
index 0000000000000..e3e901a7eaf02
--- /dev/null
+++ b/mlir/test/Dialect/Transform/interpreter-entry-point-2.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt %s -transform-interpreter \
+// RUN:   -split-input-file -verify-diagnostics | FileCheck %s
+
+module @td_module_4 attributes {transform.with_named_sequence} {
+  module @foo_module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) -> () {
+      // CHECK: IR printer: foo_module top-level
+      transform.print {name="foo_module"}
+      transform.yield
+    }
+  }
+  module @bar_module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) -> () {
+      // CHECK: IR printer: bar_module top-level
+      transform.print {name="bar_module"}
+      transform.yield
+    }
+  }
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) -> () {
+    transform.include @foo_module::@__transform_main failures(suppress) (%arg0) : (!transform.any_op) -> ()
+    transform.include @bar_module::@__transform_main failures(suppress) (%arg0) : (!transform.any_op) -> ()
+    transform.yield
+  }
+}

>From 4cbdc979ca6b9a1d954ca29fc562864998693030 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 31 Jul 2025 09:54:46 -0400
Subject: [PATCH 2/6] Outline strategies for findTransformEntryPoint.

---
 .../Transforms/TransformInterpreterUtils.cpp  | 63 ++++++++++++-------
 1 file changed, 39 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
index dc597d6ef2a1b..8c902f8de786e 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
@@ -121,6 +121,41 @@ ModuleOp transform::detail::getPreloadedTransformModule(MLIRContext *context) {
       ->getLibraryModule();
 }
 
+namespace {
+
+transform::TransformOpInterface
+findTransformEntryPointNonRecursive(Operation *op, StringRef entryPoint) {
+  for (Region &region : op->getRegions()) {
+    for (Block &block : region.getBlocks()) {
+      auto namedSequenceOps = block.getOps<transform::NamedSequenceOp>();
+      for (transform::NamedSequenceOp namedSequenceOp : namedSequenceOps) {
+        if (namedSequenceOp.getSymName() == entryPoint) {
+          return cast<transform::TransformOpInterface>(
+              namedSequenceOp.getOperation());
+        }
+      }
+    }
+  }
+  return nullptr;
+}
+
+transform::TransformOpInterface
+findTransformEntryPointRecursive(Operation *op, StringRef entryPoint) {
+  transform::TransformOpInterface transform = nullptr;
+  op->walk<WalkOrder::PreOrder>(
+      [&](transform::NamedSequenceOp namedSequenceOp) {
+        if (namedSequenceOp.getSymName() == entryPoint) {
+          transform = cast<transform::TransformOpInterface>(
+              namedSequenceOp.getOperation());
+          return WalkResult::interrupt();
+        }
+        return WalkResult::advance();
+      });
+  return transform;
+}
+
+} // namespace
+
 transform::TransformOpInterface
 transform::detail::findTransformEntryPoint(Operation *root, ModuleOp module,
                                            StringRef entryPoint) {
@@ -128,30 +163,10 @@ transform::detail::findTransformEntryPoint(Operation *root, ModuleOp module,
   if (module)
     l.push_back(module);
   for (Operation *op : l) {
-    transform::TransformOpInterface transform = nullptr;
-    for (Region &region : op->getRegions()) {
-      for (Block &block : region.getBlocks()) {
-        auto namedSequenceOps = block.getOps<transform::NamedSequenceOp>();
-        for (transform::NamedSequenceOp namedSequenceOp : namedSequenceOps) {
-          if (namedSequenceOp.getSymName() == entryPoint) {
-            transform = cast<transform::TransformOpInterface>(
-                namedSequenceOp.getOperation());
-            break;
-          }
-        }
-      }
-    }
-    if (!transform) {
-      op->walk<WalkOrder::PreOrder>(
-          [&](transform::NamedSequenceOp namedSequenceOp) {
-            if (namedSequenceOp.getSymName() == entryPoint) {
-              transform = cast<transform::TransformOpInterface>(
-                  namedSequenceOp.getOperation());
-              return WalkResult::interrupt();
-            }
-            return WalkResult::advance();
-          });
-    }
+    TransformOpInterface transform =
+        findTransformEntryPointNonRecursive(op, entryPoint);
+    if (!transform)
+      transform = findTransformEntryPointRecursive(op, entryPoint);
     if (transform)
       return transform;
   }

>From a69dd77e968d2c477fdc2b0951fd0bf6a4383475 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 31 Jul 2025 09:59:35 -0400
Subject: [PATCH 3/6] Outline findTransformEntryPointInOp.

---
 .../Transforms/TransformInterpreterUtils.cpp        | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
index 8c902f8de786e..f4eca8370359e 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
@@ -154,6 +154,15 @@ findTransformEntryPointRecursive(Operation *op, StringRef entryPoint) {
   return transform;
 }
 
+transform::TransformOpInterface
+findTransformEntryPointInOp(Operation *op, StringRef entryPoint) {
+  transform::TransformOpInterface transform =
+      findTransformEntryPointNonRecursive(op, entryPoint);
+  if (!transform)
+    transform = findTransformEntryPointRecursive(op, entryPoint);
+  return transform;
+}
+
 } // namespace
 
 transform::TransformOpInterface
@@ -164,9 +173,7 @@ transform::detail::findTransformEntryPoint(Operation *root, ModuleOp module,
     l.push_back(module);
   for (Operation *op : l) {
     TransformOpInterface transform =
-        findTransformEntryPointNonRecursive(op, entryPoint);
-    if (!transform)
-      transform = findTransformEntryPointRecursive(op, entryPoint);
+        findTransformEntryPointInOp(op, entryPoint);
     if (transform)
       return transform;
   }

>From fe8ee9b071fbbc1304070fb58a3df448f7a0d5b4 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 31 Jul 2025 10:07:14 -0400
Subject: [PATCH 4/6] Add comments.

---
 .../Transforms/TransformInterpreterUtils.cpp  | 35 +++++++++++++++++++
 1 file changed, 35 insertions(+)

diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
index f4eca8370359e..c893d267763a1 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
@@ -154,6 +154,41 @@ findTransformEntryPointRecursive(Operation *op, StringRef entryPoint) {
   return transform;
 }
 
+// Looks for the transform's entry point, favouring NamedSequenceOps that are
+// immediate children of op (i.e., ops in the blocks of op's regions). If no
+// such op has the entry point's name, it falls back to recursively walking op
+// in pre-order and returns the first NamedSequenceOp whose name matches the
+// entry point's name.
+//
+// This allows for the following two use cases:
+// 1. op is a module annotated with the transform.with_named_sequence attribute
+//    that has an entry point in its block. E.g.,
+//
+//    ```mlir
+//    module attributes {transform.with_named_sequence} {
+//      transform.named_sequence @__transform_main(%arg0 : !transform.any_op) ->
+//      () {
+//        transform.yield
+//      }
+//    }
+//    ```
+//
+// 2. op is a program which contains a nested module annotated with the
+//    transform.with_named_sequence attribute. E.g.,
+//
+//    ```mlir
+//    module {
+//      func.func @foo () {
+//      }
+//
+//      module attributes {transform.with_named_sequence} {
+//        transform.named_sequence @__transform_main(%arg0 : !transform.any_op)
+//        -> () {
+//          transform.yield
+//        }
+//      }
+//    }
+//    ```
 transform::TransformOpInterface
 findTransformEntryPointInOp(Operation *op, StringRef entryPoint) {
   transform::TransformOpInterface transform =

>From c204fd14ae715e47dbde7849f1ec5df866a6bc03 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 31 Jul 2025 10:36:04 -0400
Subject: [PATCH 5/6] Apply suggestions from review.

---
 .../Transforms/TransformInterpreterUtils.cpp        | 13 ++++---------
 1 file changed, 4 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
index c893d267763a1..9ab484ff68078 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
@@ -121,14 +121,11 @@ ModuleOp transform::detail::getPreloadedTransformModule(MLIRContext *context) {
       ->getLibraryModule();
 }
 
-namespace {
-
-transform::TransformOpInterface
+static transform::TransformOpInterface
 findTransformEntryPointNonRecursive(Operation *op, StringRef entryPoint) {
   for (Region &region : op->getRegions()) {
     for (Block &block : region.getBlocks()) {
-      auto namedSequenceOps = block.getOps<transform::NamedSequenceOp>();
-      for (transform::NamedSequenceOp namedSequenceOp : namedSequenceOps) {
+      for (auto namedSequenceOp : block.getOps<transform::NamedSequenceOp>()) {
         if (namedSequenceOp.getSymName() == entryPoint) {
           return cast<transform::TransformOpInterface>(
               namedSequenceOp.getOperation());
@@ -139,7 +136,7 @@ findTransformEntryPointNonRecursive(Operation *op, StringRef entryPoint) {
   return nullptr;
 }
 
-transform::TransformOpInterface
+static transform::TransformOpInterface
 findTransformEntryPointRecursive(Operation *op, StringRef entryPoint) {
   transform::TransformOpInterface transform = nullptr;
   op->walk<WalkOrder::PreOrder>(
@@ -189,7 +186,7 @@ findTransformEntryPointRecursive(Operation *op, StringRef entryPoint) {
 //      }
 //    }
 //    ```
-transform::TransformOpInterface
+static transform::TransformOpInterface
 findTransformEntryPointInOp(Operation *op, StringRef entryPoint) {
   transform::TransformOpInterface transform =
       findTransformEntryPointNonRecursive(op, entryPoint);
@@ -198,8 +195,6 @@ findTransformEntryPointInOp(Operation *op, StringRef entryPoint) {
   return transform;
 }
 
-} // namespace
-
 transform::TransformOpInterface
 transform::detail::findTransformEntryPoint(Operation *root, ModuleOp module,
                                            StringRef entryPoint) {

>From d689e24a2880ab20f3878c15e4e6c7c01ee3f020 Mon Sep 17 00:00:00 2001
From: Erick Ochoa Lopez <eochoalo at amd.com>
Date: Thu, 31 Jul 2025 10:56:15 -0400
Subject: [PATCH 6/6] Update
 mlir/test/Dialect/Transform/interpreter-entry-point-2.mlir

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
 mlir/test/Dialect/Transform/interpreter-entry-point-2.mlir | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/Transform/interpreter-entry-point-2.mlir b/mlir/test/Dialect/Transform/interpreter-entry-point-2.mlir
index e3e901a7eaf02..5c97c4c25ea41 100644
--- a/mlir/test/Dialect/Transform/interpreter-entry-point-2.mlir
+++ b/mlir/test/Dialect/Transform/interpreter-entry-point-2.mlir
@@ -1,5 +1,4 @@
-// RUN: mlir-opt %s -transform-interpreter \
-// RUN:   -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
 
 module @td_module_4 attributes {transform.with_named_sequence} {
   module @foo_module attributes {transform.with_named_sequence} {



More information about the Mlir-commits mailing list