[Mlir-commits] [mlir] [mlir][transform] Support for multiple top-level transform ops (PR #69615)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 19 10:22:06 PDT 2023


https://github.com/martin-luecke updated https://github.com/llvm/llvm-project/pull/69615

>From 23add3cff2354e33fc6094eb1e7bd97f9ccb88d1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Martin=20L=C3=BCcke?= <martin.luecke at ed.ac.uk>
Date: Thu, 19 Oct 2023 17:21:24 +0000
Subject: [PATCH] [mlir][transform] Support for multiple top-level transform
 ops

---
 .../Transform/IR/TransformInterfaces.h        | 12 +++++++++
 .../TransformInterpreterPassBase.cpp          | 24 +++++++++--------
 ...st-interpreter-multiple-top-level-ops.mlir | 26 +++++++++++++++++++
 .../TestTransformDialectInterpreter.cpp       |  6 +++++
 4 files changed, 57 insertions(+), 11 deletions(-)
 create mode 100644 mlir/test/Dialect/Transform/test-interpreter-multiple-top-level-ops.mlir

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 7b37245fc3d117b..5fcde11d52f03d9 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -95,11 +95,23 @@ class TransformOptions {
     return *this;
   }
 
+  /// Enables or disables enforcing a single top-level transform op in the IR.
+  TransformOptions &enableEnforceSingleToplevelTransformOp(bool enable = true) {
+    enforceSingleToplevelTransformOp = enable;
+    return *this;
+  }
+
   /// Returns true if the expensive checks are requested.
   bool getExpensiveChecksEnabled() const { return expensiveChecksEnabled; }
 
+  /// Returns true if enforcing a single top-level transform op is requested.
+  bool getEnforceSingleToplevelTransformOp() const {
+    return enforceSingleToplevelTransformOp;
+  }
+
 private:
   bool expensiveChecksEnabled = true;
+  bool enforceSingleToplevelTransformOp = true;
 };
 
 /// Entry point to the Transform dialect infrastructure. Applies the
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index 538c81fe39fddb2..741456e7ebbfb86 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -56,10 +56,11 @@ constexpr static llvm::StringLiteral
 /// Reports an error if there is more than one such operation and returns the
 /// first one found. Reports an error returns nullptr if no such operation
 /// found.
-static Operation *findTopLevelTransform(Operation *root,
-                                        StringRef filenameOption) {
+static Operation *
+findTopLevelTransform(Operation *root, StringRef filenameOption,
+                      mlir::transform::TransformOptions options) {
   ::mlir::transform::TransformOpInterface topLevelTransform = nullptr;
-  WalkResult walkResult = root->walk<WalkOrder::PreOrder>(
+  root->walk<WalkOrder::PreOrder>(
       [&](::mlir::transform::TransformOpInterface transformOp) {
         if (!transformOp
                  ->hasTrait<transform::PossibleTopLevelTransformOpTrait>())
@@ -68,14 +69,15 @@ static Operation *findTopLevelTransform(Operation *root,
           topLevelTransform = transformOp;
           return WalkResult::skip();
         }
-        auto diag = transformOp.emitError()
-                    << "more than one top-level transform op";
-        diag.attachNote(topLevelTransform.getLoc())
-            << "previous top-level transform op";
-        return WalkResult::interrupt();
+        if (options.getEnforceSingleToplevelTransformOp()) {
+          auto diag = transformOp.emitError()
+                      << "more than one top-level transform op";
+          diag.attachNote(topLevelTransform.getLoc())
+              << "previous top-level transform op";
+          return WalkResult::interrupt();
+        }
+        return WalkResult::skip();
       });
-  if (walkResult.wasInterrupted())
-    return nullptr;
   if (!topLevelTransform) {
     auto diag = root->emitError()
                 << "could not find a nested top-level transform op";
@@ -310,7 +312,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
   Operation *transformRoot =
       debugTransformRootTag.empty()
           ? findTopLevelTransform(transformContainer,
-                                  transformFileName.getArgStr())
+                                  transformFileName.getArgStr(), options)
           : findOpWithTag(transformContainer, kTransformDialectTagAttrName,
                           debugTransformRootTag);
   if (!transformRoot)
diff --git a/mlir/test/Dialect/Transform/test-interpreter-multiple-top-level-ops.mlir b/mlir/test/Dialect/Transform/test-interpreter-multiple-top-level-ops.mlir
new file mode 100644
index 000000000000000..db7fecdf753e984
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-multiple-top-level-ops.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter='enforce-single-top-level-transform-op=0' -allow-unregistered-dialect --split-input-file --verify-diagnostics | FileCheck %s
+
+transform.sequence failures(propagate) {
+// CHECK: transform.sequence
+^bb0(%arg0: !transform.any_op):
+}
+
+transform.sequence failures(propagate) {
+// CHECK: transform.sequence
+^bb0(%arg0: !transform.any_op):
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  %match = transform.structured.match ops{["transform.get_parent_op"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  transform.test_print_remark_at_operand %match, "found get_parent_op" : !transform.any_op
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  %op = transform.structured.match ops{[]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  // expected-remark @below{{found get_parent_op}}
+  %1 = transform.get_parent_op %op : (!transform.any_op) -> !transform.any_op
+}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
index c60b21c918338b4..756b7f669b0c5bf 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
@@ -158,6 +158,8 @@ class TestTransformDialectInterpreterPass
     }
 
     options = options.enableExpensiveChecks(enableExpensiveChecks);
+    options = options.enableEnforceSingleToplevelTransformOp(
+        enforceSingleToplevelTransformOp);
     if (failed(transform::detail::interpreterBaseRunOnOperationImpl(
             getOperation(), getArgument(), getSharedTransformModule(),
             getTransformLibraryModule(), extraMapping, options,
@@ -170,6 +172,10 @@ class TestTransformDialectInterpreterPass
       *this, "enable-expensive-checks", llvm::cl::init(false),
       llvm::cl::desc("perform expensive checks to better report errors in the "
                      "transform IR")};
+  Option<bool> enforceSingleToplevelTransformOp{
+      *this, "enforce-single-top-level-transform-op", llvm::cl::init(true),
+      llvm::cl::desc("Ensure that only a single top-level transform op is "
+                     "present in the IR.")};
 
   Option<std::string> bindFirstExtraToOps{
       *this, "bind-first-extra-to-ops",



More information about the Mlir-commits mailing list