[Mlir-commits] [mlir] 9ccf01f - [mlir][transform] Support for multiple top-level transform ops (#69615)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 20 02:42:25 PDT 2023


Author: martin-luecke
Date: 2023-10-20T11:42:20+02:00
New Revision: 9ccf01fbf711f4f9b0aa82d5732d778ea169fa88

URL: https://github.com/llvm/llvm-project/commit/9ccf01fbf711f4f9b0aa82d5732d778ea169fa88
DIFF: https://github.com/llvm/llvm-project/commit/9ccf01fbf711f4f9b0aa82d5732d778ea169fa88.diff

LOG: [mlir][transform] Support for multiple top-level transform ops (#69615)

This adds a flag to the `TransformDialectInterpreter` that relaxes the
requirement for only a single top-level transform op.
This is useful for supporting transforms that take transform IR as
payload.

This also aligns the function `findTopLevelTransform`
[here](https://github.com/llvm/llvm-project/commit/7b0f4c9db55c355bffddf94d7710f40ee2c1e9db#diff-551f92bb609487ccf981daf9571f0f1b1703ab2330560a388a5f0d133e520be4L59)
with its documentation:
In the presence of multiple top-level transform ops, it now correctly
returns the first of them after reporting the error, instead of returning
`nullptr`.

Added: 
    mlir/test/Dialect/Transform/test-interpreter-multiple-top-level-ops.mlir

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
    mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
    mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 7b37245fc3d117b..5fcde11d52f03d9 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -95,11 +95,23 @@ class TransformOptions {
     return *this;
   }
 
+  // Ensures that only a single top-level transform op is present in the IR.
+  TransformOptions &enableEnforceSingleToplevelTransformOp(bool enable = true) {
+    enforceSingleToplevelTransformOp = enable;
+    return *this;
+  }
+
   /// Returns true if the expensive checks are requested.
   bool getExpensiveChecksEnabled() const { return expensiveChecksEnabled; }
 
+  // Returns true if enforcing a single top-level transform op is requested.
+  bool getEnforceSingleToplevelTransformOp() const {
+    return enforceSingleToplevelTransformOp;
+  }
+
 private:
   bool expensiveChecksEnabled = true;
+  bool enforceSingleToplevelTransformOp = true;
 };
 
 /// Entry point to the Transform dialect infrastructure. Applies the

diff  --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index 538c81fe39fddb2..741456e7ebbfb86 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -56,10 +56,11 @@ constexpr static llvm::StringLiteral
 /// Reports an error if there is more than one such operation and returns the
 /// first one found. Reports an error and returns nullptr if no such
 /// operation is found.
-static Operation *findTopLevelTransform(Operation *root,
-                                        StringRef filenameOption) {
+static Operation *
+findTopLevelTransform(Operation *root, StringRef filenameOption,
+                      mlir::transform::TransformOptions options) {
   ::mlir::transform::TransformOpInterface topLevelTransform = nullptr;
-  WalkResult walkResult = root->walk<WalkOrder::PreOrder>(
+  root->walk<WalkOrder::PreOrder>(
       [&](::mlir::transform::TransformOpInterface transformOp) {
         if (!transformOp
                  ->hasTrait<transform::PossibleTopLevelTransformOpTrait>())
@@ -68,14 +69,15 @@ static Operation *findTopLevelTransform(Operation *root,
           topLevelTransform = transformOp;
           return WalkResult::skip();
         }
-        auto diag = transformOp.emitError()
-                    << "more than one top-level transform op";
-        diag.attachNote(topLevelTransform.getLoc())
-            << "previous top-level transform op";
-        return WalkResult::interrupt();
+        if (options.getEnforceSingleToplevelTransformOp()) {
+          auto diag = transformOp.emitError()
+                      << "more than one top-level transform op";
+          diag.attachNote(topLevelTransform.getLoc())
+              << "previous top-level transform op";
+          return WalkResult::interrupt();
+        }
+        return WalkResult::skip();
       });
-  if (walkResult.wasInterrupted())
-    return nullptr;
   if (!topLevelTransform) {
     auto diag = root->emitError()
                 << "could not find a nested top-level transform op";
@@ -310,7 +312,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
   Operation *transformRoot =
       debugTransformRootTag.empty()
           ? findTopLevelTransform(transformContainer,
-                                  transformFileName.getArgStr())
+                                  transformFileName.getArgStr(), options)
           : findOpWithTag(transformContainer, kTransformDialectTagAttrName,
                           debugTransformRootTag);
   if (!transformRoot)

diff  --git a/mlir/test/Dialect/Transform/test-interpreter-multiple-top-level-ops.mlir b/mlir/test/Dialect/Transform/test-interpreter-multiple-top-level-ops.mlir
new file mode 100644
index 000000000000000..db7fecdf753e984
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-multiple-top-level-ops.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter='enforce-single-top-level-transform-op=0' -allow-unregistered-dialect --split-input-file --verify-diagnostics | FileCheck %s
+
+transform.sequence failures(propagate) {
+// CHECK: transform.sequence
+^bb0(%arg0: !transform.any_op):
+}
+
+transform.sequence failures(propagate) {
+// CHECK: transform.sequence
+^bb0(%arg0: !transform.any_op):
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  %match = transform.structured.match ops{["transform.get_parent_op"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  transform.test_print_remark_at_operand %match, "found get_parent_op" : !transform.any_op
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  %op = transform.structured.match ops{[]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  // expected-remark @below{{found get_parent_op}}
+  %1 = transform.get_parent_op %op : (!transform.any_op) -> !transform.any_op
+}

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
index c60b21c918338b4..756b7f669b0c5bf 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
@@ -158,6 +158,8 @@ class TestTransformDialectInterpreterPass
     }
 
     options = options.enableExpensiveChecks(enableExpensiveChecks);
+    options = options.enableEnforceSingleToplevelTransformOp(
+        enforceSingleToplevelTransformOp);
     if (failed(transform::detail::interpreterBaseRunOnOperationImpl(
             getOperation(), getArgument(), getSharedTransformModule(),
             getTransformLibraryModule(), extraMapping, options,
@@ -170,6 +172,10 @@ class TestTransformDialectInterpreterPass
       *this, "enable-expensive-checks", llvm::cl::init(false),
       llvm::cl::desc("perform expensive checks to better report errors in the "
                      "transform IR")};
+  Option<bool> enforceSingleToplevelTransformOp{
+      *this, "enforce-single-top-level-transform-op", llvm::cl::init(true),
+      llvm::cl::desc("Ensure that only a single top-level transform op is "
+                     "present in the IR.")};
 
   Option<std::string> bindFirstExtraToOps{
       *this, "bind-first-extra-to-ops",


        


More information about the Mlir-commits mailing list