[Mlir-commits] [mlir] [mlir][transform] Support for multiple top-level transform ops (PR #69615)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 19 09:51:56 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: None (martin-luecke)

<details>
<summary>Changes</summary>

This adds a flag to the `TransformDialectInterpreter` that relaxes the requirement that only a single top-level transform op be present.
This is useful for supporting transforms that take transform IR as payload.

This also aligns the function `findTopLevelTransform` [here](https://github.com/llvm/llvm-project/commit/7b0f4c9db55c355bffddf94d7710f40ee2c1e9db#diff-551f92bb609487ccf981daf9571f0f1b1703ab2330560a388a5f0d133e520be4L59) with its documentation:
In the presence of multiple top-level transform ops, it now correctly returns the first of them after reporting the error, instead of returning `nullptr`.

---
Full diff: https://github.com/llvm/llvm-project/pull/69615.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h (+13) 
- (modified) mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp (+13-11) 
- (added) mlir/test/Dialect/Transform/test-interpreter-multiple-top-level-ops.mlir (+26) 
- (modified) mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp (+6) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 7b37245fc3d117b..60eb48a764eb4b2 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -95,11 +95,24 @@ class TransformOptions {
     return *this;
   }
 
+  // Ensures that only a single top-level transform op is present in the IR.
+  TransformOptions &
+  enableEnforceSingleToplevelTransformOp(bool enable = true) {
+    enforceSingleToplevelTransformOp = enable;
+    return *this;
+  }
+
   /// Returns true if the expensive checks are requested.
   bool getExpensiveChecksEnabled() const { return expensiveChecksEnabled; }
 
+  // Returns true if enforcing a single top-level transform op is requested.
+  bool getEnforceSingleToplevelTransformOp() const {
+    return enforceSingleToplevelTransformOp;
+  }
+
 private:
   bool expensiveChecksEnabled = true;
+  bool enforceSingleToplevelTransformOp = true;
 };
 
 /// Entry point to the Transform dialect infrastructure. Applies the
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index 538c81fe39fddb2..741456e7ebbfb86 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -56,10 +56,11 @@ constexpr static llvm::StringLiteral
 /// Reports an error if there is more than one such operation and returns the
 /// first one found. Reports an error returns nullptr if no such operation
 /// found.
-static Operation *findTopLevelTransform(Operation *root,
-                                        StringRef filenameOption) {
+static Operation *
+findTopLevelTransform(Operation *root, StringRef filenameOption,
+                      mlir::transform::TransformOptions options) {
   ::mlir::transform::TransformOpInterface topLevelTransform = nullptr;
-  WalkResult walkResult = root->walk<WalkOrder::PreOrder>(
+  root->walk<WalkOrder::PreOrder>(
       [&](::mlir::transform::TransformOpInterface transformOp) {
         if (!transformOp
                  ->hasTrait<transform::PossibleTopLevelTransformOpTrait>())
@@ -68,14 +69,15 @@ static Operation *findTopLevelTransform(Operation *root,
           topLevelTransform = transformOp;
           return WalkResult::skip();
         }
-        auto diag = transformOp.emitError()
-                    << "more than one top-level transform op";
-        diag.attachNote(topLevelTransform.getLoc())
-            << "previous top-level transform op";
-        return WalkResult::interrupt();
+        if (options.getEnforceSingleToplevelTransformOp()) {
+          auto diag = transformOp.emitError()
+                      << "more than one top-level transform op";
+          diag.attachNote(topLevelTransform.getLoc())
+              << "previous top-level transform op";
+          return WalkResult::interrupt();
+        }
+        return WalkResult::skip();
       });
-  if (walkResult.wasInterrupted())
-    return nullptr;
   if (!topLevelTransform) {
     auto diag = root->emitError()
                 << "could not find a nested top-level transform op";
@@ -310,7 +312,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
   Operation *transformRoot =
       debugTransformRootTag.empty()
           ? findTopLevelTransform(transformContainer,
-                                  transformFileName.getArgStr())
+                                  transformFileName.getArgStr(), options)
           : findOpWithTag(transformContainer, kTransformDialectTagAttrName,
                           debugTransformRootTag);
   if (!transformRoot)
diff --git a/mlir/test/Dialect/Transform/test-interpreter-multiple-top-level-ops.mlir b/mlir/test/Dialect/Transform/test-interpreter-multiple-top-level-ops.mlir
new file mode 100644
index 000000000000000..db7fecdf753e984
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-multiple-top-level-ops.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter='enforce-single-top-level-transform-op=0' -allow-unregistered-dialect --split-input-file --verify-diagnostics | FileCheck %s
+
+transform.sequence failures(propagate) {
+// CHECK: transform.sequence
+^bb0(%arg0: !transform.any_op):
+}
+
+transform.sequence failures(propagate) {
+// CHECK: transform.sequence
+^bb0(%arg0: !transform.any_op):
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  %match = transform.structured.match ops{["transform.get_parent_op"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  transform.test_print_remark_at_operand %match, "found get_parent_op" : !transform.any_op
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  %op = transform.structured.match ops{[]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  // expected-remark @below{{found get_parent_op}}
+  %1 = transform.get_parent_op %op : (!transform.any_op) -> !transform.any_op
+}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
index c60b21c918338b4..756b7f669b0c5bf 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
@@ -158,6 +158,8 @@ class TestTransformDialectInterpreterPass
     }
 
     options = options.enableExpensiveChecks(enableExpensiveChecks);
+    options = options.enableEnforceSingleToplevelTransformOp(
+        enforceSingleToplevelTransformOp);
     if (failed(transform::detail::interpreterBaseRunOnOperationImpl(
             getOperation(), getArgument(), getSharedTransformModule(),
             getTransformLibraryModule(), extraMapping, options,
@@ -170,6 +172,10 @@ class TestTransformDialectInterpreterPass
       *this, "enable-expensive-checks", llvm::cl::init(false),
       llvm::cl::desc("perform expensive checks to better report errors in the "
                      "transform IR")};
+  Option<bool> enforceSingleToplevelTransformOp{
+      *this, "enforce-single-top-level-transform-op", llvm::cl::init(true),
+      llvm::cl::desc("Ensure that only a single top-level transform op is "
+                     "present in the IR.")};
 
   Option<std::string> bindFirstExtraToOps{
       *this, "bind-first-extra-to-ops",

``````````

</details>


https://github.com/llvm/llvm-project/pull/69615


More information about the Mlir-commits mailing list