[Mlir-commits] [mlir] [mlir][transform] Support for multiple top-level transform ops (PR #69615)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 19 09:51:56 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: None (martin-luecke)
<details>
<summary>Changes</summary>
This adds a flag to the `TransformDialectInterpreter` that relaxes the requirement that only a single top-level transform op be present.
This is useful for supporting transforms that take transform IR as their payload.
This also aligns the function `findTopLevelTransform` [here](https://github.com/llvm/llvm-project/commit/7b0f4c9db55c355bffddf94d7710f40ee2c1e9db#diff-551f92bb609487ccf981daf9571f0f1b1703ab2330560a388a5f0d133e520be4L59) with its documentation:
In the presence of multiple top-level transform ops, it now correctly returns the first of them after reporting the error, instead of returning `nullptr`.
---
Full diff: https://github.com/llvm/llvm-project/pull/69615.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h (+13)
- (modified) mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp (+13-11)
- (added) mlir/test/Dialect/Transform/test-interpreter-multiple-top-level-ops.mlir (+26)
- (modified) mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp (+6)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 7b37245fc3d117b..60eb48a764eb4b2 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -95,11 +95,24 @@ class TransformOptions {
return *this;
}
+ // Ensures that only a single top-level transform op is present in the IR.
+ TransformOptions &
+ enableEnforceSingleToplevelTransformOp(bool enable = true) {
+ enforceSingleToplevelTransformOp = enable;
+ return *this;
+ }
+
/// Returns true if the expensive checks are requested.
bool getExpensiveChecksEnabled() const { return expensiveChecksEnabled; }
+ // Returns true if enforcing a single top-level transform op is requested.
+ bool getEnforceSingleToplevelTransformOp() const {
+ return enforceSingleToplevelTransformOp;
+ }
+
private:
bool expensiveChecksEnabled = true;
+ bool enforceSingleToplevelTransformOp = true;
};
/// Entry point to the Transform dialect infrastructure. Applies the
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index 538c81fe39fddb2..741456e7ebbfb86 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -56,10 +56,11 @@ constexpr static llvm::StringLiteral
/// Reports an error if there is more than one such operation and returns the
/// first one found. Reports an error returns nullptr if no such operation
/// found.
-static Operation *findTopLevelTransform(Operation *root,
- StringRef filenameOption) {
+static Operation *
+findTopLevelTransform(Operation *root, StringRef filenameOption,
+ mlir::transform::TransformOptions options) {
::mlir::transform::TransformOpInterface topLevelTransform = nullptr;
- WalkResult walkResult = root->walk<WalkOrder::PreOrder>(
+ root->walk<WalkOrder::PreOrder>(
[&](::mlir::transform::TransformOpInterface transformOp) {
if (!transformOp
->hasTrait<transform::PossibleTopLevelTransformOpTrait>())
@@ -68,14 +69,15 @@ static Operation *findTopLevelTransform(Operation *root,
topLevelTransform = transformOp;
return WalkResult::skip();
}
- auto diag = transformOp.emitError()
- << "more than one top-level transform op";
- diag.attachNote(topLevelTransform.getLoc())
- << "previous top-level transform op";
- return WalkResult::interrupt();
+ if (options.getEnforceSingleToplevelTransformOp()) {
+ auto diag = transformOp.emitError()
+ << "more than one top-level transform op";
+ diag.attachNote(topLevelTransform.getLoc())
+ << "previous top-level transform op";
+ return WalkResult::interrupt();
+ }
+ return WalkResult::skip();
});
- if (walkResult.wasInterrupted())
- return nullptr;
if (!topLevelTransform) {
auto diag = root->emitError()
<< "could not find a nested top-level transform op";
@@ -310,7 +312,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
Operation *transformRoot =
debugTransformRootTag.empty()
? findTopLevelTransform(transformContainer,
- transformFileName.getArgStr())
+ transformFileName.getArgStr(), options)
: findOpWithTag(transformContainer, kTransformDialectTagAttrName,
debugTransformRootTag);
if (!transformRoot)
diff --git a/mlir/test/Dialect/Transform/test-interpreter-multiple-top-level-ops.mlir b/mlir/test/Dialect/Transform/test-interpreter-multiple-top-level-ops.mlir
new file mode 100644
index 000000000000000..db7fecdf753e984
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-multiple-top-level-ops.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter='enforce-single-top-level-transform-op=0' -allow-unregistered-dialect --split-input-file --verify-diagnostics | FileCheck %s
+
+transform.sequence failures(propagate) {
+// CHECK: transform.sequence
+^bb0(%arg0: !transform.any_op):
+}
+
+transform.sequence failures(propagate) {
+// CHECK: transform.sequence
+^bb0(%arg0: !transform.any_op):
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+ %match = transform.structured.match ops{["transform.get_parent_op"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.test_print_remark_at_operand %match, "found get_parent_op" : !transform.any_op
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+ %op = transform.structured.match ops{[]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ // expected-remark @below{{found get_parent_op}}
+ %1 = transform.get_parent_op %op : (!transform.any_op) -> !transform.any_op
+}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
index c60b21c918338b4..756b7f669b0c5bf 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
@@ -158,6 +158,8 @@ class TestTransformDialectInterpreterPass
}
options = options.enableExpensiveChecks(enableExpensiveChecks);
+ options = options.enableEnforceSingleToplevelTransformOp(
+ enforceSingleToplevelTransformOp);
if (failed(transform::detail::interpreterBaseRunOnOperationImpl(
getOperation(), getArgument(), getSharedTransformModule(),
getTransformLibraryModule(), extraMapping, options,
@@ -170,6 +172,10 @@ class TestTransformDialectInterpreterPass
*this, "enable-expensive-checks", llvm::cl::init(false),
llvm::cl::desc("perform expensive checks to better report errors in the "
"transform IR")};
+ Option<bool> enforceSingleToplevelTransformOp{
+ *this, "enforce-single-top-level-transform-op", llvm::cl::init(true),
+ llvm::cl::desc("Ensure that only a single top-level transform op is "
+ "present in the IR.")};
Option<std::string> bindFirstExtraToOps{
*this, "bind-first-extra-to-ops",
``````````
</details>
https://github.com/llvm/llvm-project/pull/69615
More information about the Mlir-commits mailing list