[Mlir-commits] [mlir] [mlir] update remaining transform tests to main pass (PR #81279)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Feb 9 09:04:00 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir
Author: Oleksandr "Alex" Zinenko (ftynse)
<details>
<summary>Changes</summary>
Use the main transform interpreter pass instead of the test pass. The only tests not updated are those that specifically exercise the behavior of the test pass itself.
---
Patch is 64.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81279.diff
15 Files Affected:
- (modified) mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp (+73-22)
- (modified) mlir/test/Dialect/Transform/include/test-interpreter-external-concurrent-source.mlir (+15-10)
- (modified) mlir/test/Dialect/Transform/include/test-interpreter-external-source.mlir (+8-6)
- (modified) mlir/test/Dialect/Transform/multi-arg-top-level-ops.mlir (+42-21)
- (modified) mlir/test/Dialect/Transform/multi-arg-top-level-params.mlir (+26-13)
- (modified) mlir/test/Dialect/Transform/multi-arg-top-level-values.mlir (+21-11)
- (modified) mlir/test/Dialect/Transform/test-interpreter-debug.mlir (+17-41)
- (modified) mlir/test/Dialect/Transform/test-interpreter-external-concurrent.mlir (+3-1)
- (modified) mlir/test/Dialect/Transform/test-interpreter-external.mlir (+3-1)
- (modified) mlir/test/Dialect/Transform/test-interpreter.mlir (-1)
- (modified) mlir/test/Dialect/Transform/test-pass-application.mlir (+45-33)
- (modified) mlir/test/Dialect/Transform/test-pattern-application.mlir (+206-178)
- (modified) mlir/test/Dialect/Transform/test-pdl-extension.mlir (+39-29)
- (modified) mlir/test/Dialect/Transform/transform-state-extension.mlir (+54-48)
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-contraction.mlir (+10-8)
``````````diff
diff --git a/mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp b/mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp
index 5073234a7e35e9..7adf223f3440a5 100644
--- a/mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp
@@ -50,12 +50,79 @@ static Operation *findPayloadRoot(Operation *passRoot, StringRef tag) {
return WalkResult::interrupt();
});
+ if (!target) {
+ passRoot->emitError()
+ << "could not find the operation with transform.target_tag=\"" << tag
+ << "\" attribute";
+ return nullptr;
+ }
+
return walkResult.wasInterrupted() ? nullptr : target;
}
namespace {
class InterpreterPass
: public transform::impl::InterpreterPassBase<InterpreterPass> {
+ // Parses the pass arguments to bind trailing arguments of the entry point.
+ std::optional<RaggedArray<transform::MappedValue>>
+ parseArguments(Operation *payloadRoot) {
+ MLIRContext *context = payloadRoot->getContext();
+
+ SmallVector<SmallVector<transform::MappedValue>, 2> trailingBindings;
+ trailingBindings.resize(debugBindTrailingArgs.size());
+
+ // Construct lists of op names to match.
+ SmallVector<std::optional<OperationName>> debugBindNames;
+ debugBindNames.reserve(debugBindTrailingArgs.size());
+ for (auto &&[position, nameString] :
+ llvm::enumerate(debugBindTrailingArgs)) {
+ StringRef name = nameString;
+
+ // Parse the integer literals.
+ if (name.starts_with("#")) {
+ debugBindNames.push_back(std::nullopt);
+ StringRef lhs = "";
+ StringRef rhs = name.drop_front();
+ do {
+ std::tie(lhs, rhs) = rhs.split(';');
+ int64_t value;
+ if (lhs.getAsInteger(10, value)) {
+ emitError(UnknownLoc::get(context))
+ << "couldn't parse integer pass argument " << name;
+ return std::nullopt;
+ }
+ trailingBindings[position].push_back(
+ Builder(context).getI64IntegerAttr(value));
+ } while (!rhs.empty());
+ } else if (name.starts_with("^")) {
+ debugBindNames.emplace_back(OperationName(name.drop_front(), context));
+ } else {
+ debugBindNames.emplace_back(OperationName(name, context));
+ }
+ }
+
+ // Collect operations or results for extra bindings.
+ payloadRoot->walk([&](Operation *payload) {
+ for (auto &&[position, name] : llvm::enumerate(debugBindNames)) {
+ if (!name || payload->getName() != *name)
+ continue;
+
+ if (StringRef(*std::next(debugBindTrailingArgs.begin(), position))
+ .starts_with("^")) {
+ llvm::append_range(trailingBindings[position], payload->getResults());
+ } else {
+ trailingBindings[position].push_back(payload);
+ }
+ }
+ });
+
+ RaggedArray<transform::MappedValue> bindings;
+ bindings.push_back(ArrayRef<Operation *>{payloadRoot});
+ for (SmallVector<transform::MappedValue> &trailing : trailingBindings)
+ bindings.push_back(std::move(trailing));
+ return bindings;
+ }
+
public:
using Base::Base;
@@ -67,34 +134,18 @@ class InterpreterPass
findPayloadRoot(getOperation(), debugPayloadRootTag);
if (!payloadRoot)
return signalPassFailure();
- auto debugBindNames = llvm::map_to_vector(
- debugBindTrailingArgs,
- [&](const std::string &name) { return OperationName(name, context); });
- SmallVector<SmallVector<Operation *>, 2> trailingBindings;
- trailingBindings.resize(debugBindNames.size());
- payloadRoot->walk([&](Operation *payload) {
- for (auto &&[position, name] : llvm::enumerate(debugBindNames)) {
- if (payload->getName() == name)
- trailingBindings[position].push_back(payload);
- }
- });
Operation *transformEntryPoint = transform::detail::findTransformEntryPoint(
getOperation(), transformModule, entryPoint);
- if (!transformEntryPoint) {
- getOperation()->emitError()
- << "could not find transform entry point: " << entryPoint
- << " in either payload or transform module";
+ if (!transformEntryPoint)
return signalPassFailure();
- }
-
- RaggedArray<transform::MappedValue> bindings;
- bindings.push_back(ArrayRef<Operation *>{payloadRoot});
- for (SmallVector<Operation *> &trailing : trailingBindings)
- bindings.push_back(std::move(trailing));
+ std::optional<RaggedArray<transform::MappedValue>> bindings =
+ parseArguments(payloadRoot);
+ if (!bindings)
+ return signalPassFailure();
if (failed(transform::applyTransformNamedSequence(
- bindings,
+ *bindings,
cast<transform::TransformOpInterface>(transformEntryPoint),
transformModule,
options.enableExpensiveChecks(!disableExpensiveChecks)))) {
diff --git a/mlir/test/Dialect/Transform/include/test-interpreter-external-concurrent-source.mlir b/mlir/test/Dialect/Transform/include/test-interpreter-external-concurrent-source.mlir
index 316b90f85236e4..255ff5f31ed3f5 100644
--- a/mlir/test/Dialect/Transform/include/test-interpreter-external-concurrent-source.mlir
+++ b/mlir/test/Dialect/Transform/include/test-interpreter-external-concurrent-source.mlir
@@ -1,16 +1,21 @@
// RUN: mlir-opt %s
// No need to check anything else than parsing here, this is being used by another test as data.
-transform.with_pdl_patterns {
-^bb0(%arg0: !transform.any_op):
- pdl.pattern @func_return : benefit(1) {
- %0 = pdl.operation "func.return"
- pdl.rewrite %0 with "transform.dialect"
- }
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%root: !transform.any_op) {
+ transform.with_pdl_patterns %root : !transform.any_op {
+ ^bb0(%arg0: !transform.any_op):
+ pdl.pattern @func_return : benefit(1) {
+ %0 = pdl.operation "func.return"
+ pdl.rewrite %0 with "transform.dialect"
+ }
- sequence %arg0 : !transform.any_op failures(propagate) {
- ^bb1(%arg1: !transform.any_op):
- %0 = pdl_match @func_return in %arg1 : (!transform.any_op) -> !transform.op<"func.return">
- transform.debug.emit_remark_at %0, "matched" : !transform.op<"func.return">
+ sequence %arg0 : !transform.any_op failures(propagate) {
+ ^bb1(%arg1: !transform.any_op):
+ %0 = pdl_match @func_return in %arg1 : (!transform.any_op) -> !transform.op<"func.return">
+ transform.debug.emit_remark_at %0, "matched" : !transform.op<"func.return">
+ }
+ }
+ transform.yield
}
}
diff --git a/mlir/test/Dialect/Transform/include/test-interpreter-external-source.mlir b/mlir/test/Dialect/Transform/include/test-interpreter-external-source.mlir
index 5956c86ebbe4b2..f6b7f787cc2c38 100644
--- a/mlir/test/Dialect/Transform/include/test-interpreter-external-source.mlir
+++ b/mlir/test/Dialect/Transform/include/test-interpreter-external-source.mlir
@@ -1,11 +1,13 @@
// RUN: mlir-opt %s
// No need to check anything else than parsing here, this is being used by another test as data.
-transform.sequence failures(propagate) {
-^bb0(%arg0: !transform.any_op):
- transform.debug.emit_remark_at %arg0, "outer" : !transform.any_op
- transform.sequence %arg0 : !transform.any_op failures(propagate) attributes {transform.target_tag="transform"} {
- ^bb1(%arg1: !transform.any_op):
- transform.debug.emit_remark_at %arg1, "inner" : !transform.any_op
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+ transform.debug.emit_remark_at %arg0, "outer" : !transform.any_op
+ transform.sequence %arg0 : !transform.any_op failures(propagate) attributes {transform.target_tag="transform"} {
+ ^bb1(%arg1: !transform.any_op):
+ transform.debug.emit_remark_at %arg1, "inner" : !transform.any_op
+ }
+ transform.yield
}
}
diff --git a/mlir/test/Dialect/Transform/multi-arg-top-level-ops.mlir b/mlir/test/Dialect/Transform/multi-arg-top-level-ops.mlir
index 9a7e7ca2f9536e..1c018b1b1f7796 100644
--- a/mlir/test/Dialect/Transform/multi-arg-top-level-ops.mlir
+++ b/mlir/test/Dialect/Transform/multi-arg-top-level-ops.mlir
@@ -1,10 +1,15 @@
-// RUN: mlir-opt %s --pass-pipeline='builtin.module(test-transform-dialect-interpreter{bind-first-extra-to-ops=func.func bind-second-extra-to-ops=func.return})' \
-// RUN: --split-input-file --verify-diagnostics
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(transform-interpreter{\
+// RUN: debug-bind-trailing-args=func.func,func.return})" \
+// RUN: --split-input-file --verify-diagnostics
-transform.sequence failures(propagate) {
-^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
- transform.debug.emit_remark_at %arg1, "first extra" : !transform.any_op
- transform.debug.emit_remark_at %arg2, "second extra" : !transform.any_op
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(
+ %arg0: !transform.any_op, %arg1: !transform.any_op,
+ %arg2: !transform.any_op) {
+ transform.debug.emit_remark_at %arg1, "first extra" : !transform.any_op
+ transform.debug.emit_remark_at %arg2, "second extra" : !transform.any_op
+ transform.yield
+ }
}
// expected-remark @below {{first extra}}
@@ -26,9 +31,13 @@ func.func @bar(%arg0: i1) {
// -----
-transform.sequence failures(propagate) {
-^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.param<i64>):
- // expected-error @above {{wrong kind of value provided for top-level parameter}}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(
+ %arg0: !transform.any_op, %arg1: !transform.any_op,
+ %arg2: !transform.param<i64>) {
+ // expected-error @above {{wrong kind of value provided for top-level parameter}}
+ transform.yield
+ }
}
func.func @foo() {
@@ -37,9 +46,13 @@ func.func @foo() {
// -----
-transform.sequence failures(propagate) {
-^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_value):
- // expected-error @above {{wrong kind of value provided for the top-level value handle}}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(
+ %arg0: !transform.any_op, %arg1: !transform.any_op,
+ %arg2: !transform.any_value) {
+ // expected-error @above {{wrong kind of value provided for the top-level value handle}}
+ transform.yield
+ }
}
func.func @foo() {
@@ -48,19 +61,27 @@ func.func @foo() {
// -----
-// expected-error @below {{operation expects 1 extra value bindings, but 2 were provided to the interpreter}}
-transform.sequence failures(propagate) {
-^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op):
+
+module attributes {transform.with_named_sequence} {
+ // expected-error @below {{operation expects 1 extra value bindings, but 2 were provided to the interpreter}}
+ transform.named_sequence @__transform_main(
+ %arg0: !transform.any_op, %arg1: !transform.any_op) {
+ transform.yield
+ }
}
// -----
-transform.sequence failures(propagate) {
-^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
- transform.sequence %arg0, %arg1, %arg2 : !transform.any_op, !transform.any_op, !transform.any_op failures(propagate) {
- ^bb0(%arg3: !transform.any_op, %arg4: !transform.any_op, %arg5: !transform.any_op):
- transform.debug.emit_remark_at %arg4, "first extra" : !transform.any_op
- transform.debug.emit_remark_at %arg5, "second extra" : !transform.any_op
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(
+ %arg0: !transform.any_op, %arg1: !transform.any_op,
+ %arg2: !transform.any_op) {
+ transform.sequence %arg0, %arg1, %arg2 : !transform.any_op, !transform.any_op, !transform.any_op failures(propagate) {
+ ^bb0(%arg3: !transform.any_op, %arg4: !transform.any_op, %arg5: !transform.any_op):
+ transform.debug.emit_remark_at %arg4, "first extra" : !transform.any_op
+ transform.debug.emit_remark_at %arg5, "second extra" : !transform.any_op
+ }
+ transform.yield
}
}
diff --git a/mlir/test/Dialect/Transform/multi-arg-top-level-params.mlir b/mlir/test/Dialect/Transform/multi-arg-top-level-params.mlir
index f59a4b6d4ccc32..6486bcae3294e4 100644
--- a/mlir/test/Dialect/Transform/multi-arg-top-level-params.mlir
+++ b/mlir/test/Dialect/Transform/multi-arg-top-level-params.mlir
@@ -1,24 +1,37 @@
-// RUN: mlir-opt %s --pass-pipeline='builtin.module(test-transform-dialect-interpreter{bind-first-extra-to-params=1,2,3 bind-second-extra-to-params=42,45})' \
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(transform-interpreter{\
+// RUN: debug-bind-trailing-args=#1;2;3,#42;45})' \
// RUN: --split-input-file --verify-diagnostics
-transform.sequence failures(propagate) {
-^bb0(%arg0: !transform.any_op, %arg1: !transform.param<i64>, %arg2: !transform.param<i64>):
- // expected-remark @below {{1 : i64, 2 : i64, 3 : i64}}
- transform.debug.emit_param_as_remark %arg1 : !transform.param<i64>
- // expected-remark @below {{42 : i64, 45 : i64}}
- transform.debug.emit_param_as_remark %arg2 : !transform.param<i64>
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(
+ %arg0: !transform.any_op, %arg1: !transform.param<i64>,
+ %arg2: !transform.param<i64>) {
+ // expected-remark @below {{1 : i64, 2 : i64, 3 : i64}}
+ transform.debug.emit_param_as_remark %arg1 : !transform.param<i64>
+ // expected-remark @below {{42 : i64, 45 : i64}}
+ transform.debug.emit_param_as_remark %arg2 : !transform.param<i64>
+ transform.yield
+ }
}
// -----
-transform.sequence failures(propagate) {
-^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.param<i64>):
- // expected-error @above {{wrong kind of value provided for top-level operation handle}}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(
+ %arg0: !transform.any_op, %arg1: !transform.any_op,
+ // expected-error @above {{wrong kind of value provided for top-level operation handle}}
+ %arg2: !transform.param<i64>) {
+ transform.yield
+ }
}
// -----
-// expected-error @below {{operation expects 3 extra value bindings, but 2 were provided to the interpreter}}
-transform.sequence failures(propagate) {
-^bb0(%arg0: !transform.any_op, %arg1: !transform.param<i64>, %arg2: !transform.param<i64>, %arg3: !transform.param<i64>):
+module attributes {transform.with_named_sequence} {
+ // expected-error @below {{operation expects 3 extra value bindings, but 2 were provided to the interpreter}}
+ transform.named_sequence @__transform_main(
+ %arg0: !transform.any_op, %arg1: !transform.param<i64>,
+ %arg2: !transform.param<i64>, %arg3: !transform.param<i64>) {
+ transform.yield
+ }
}
diff --git a/mlir/test/Dialect/Transform/multi-arg-top-level-values.mlir b/mlir/test/Dialect/Transform/multi-arg-top-level-values.mlir
index 38d7e28697774d..dcc1079267dc7c 100644
--- a/mlir/test/Dialect/Transform/multi-arg-top-level-values.mlir
+++ b/mlir/test/Dialect/Transform/multi-arg-top-level-values.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s --pass-pipeline='builtin.module(test-transform-dialect-interpreter{bind-first-extra-to-results-of-ops=test.some_returning_op bind-second-extra-to-results-of-ops=test.some_other_returning_op})' \
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(transform-interpreter{\
+// RUN: debug-bind-trailing-args=^test.some_returning_op,^test.some_other_returning_op})' \
// RUN: --split-input-file --verify-diagnostics
// Note that diagnostic checker will merge two diagnostics with the same message
@@ -21,10 +22,14 @@
// expected-note @below {{value handle points to an op result #1}}
%2:2 = "test.some_other_returning_op"() : () -> (f32, f64)
-transform.sequence failures(propagate) {
-^bb0(%arg0: !transform.any_op, %arg1: !transform.any_value, %arg2: !transform.any_value):
- transform.debug.emit_remark_at %arg1, "first extra" : !transform.any_value
- transform.debug.emit_remark_at %arg2, "second extra" : !transform.any_value
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(
+ %arg0: !transform.any_op, %arg1: !transform.any_value,
+ %arg2: !transform.any_value) {
+ transform.debug.emit_remark_at %arg1, "first extra" : !transform.any_value
+ transform.debug.emit_remark_at %arg2, "second extra" : !transform.any_value
+ transform.yield
+ }
}
// -----
@@ -32,14 +37,19 @@ transform.sequence failures(propagate) {
%0:2 = "test.some_returning_op"() : () -> (i32, i64)
%1 = "test.some_returning_op"() : () -> index
-transform.sequence failures(propagate) {
-// expected-error @below {{wrong kind of value provided for top-level operation handle}}
-^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_value):
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(
+ // expected-error @below {{wrong kind of value provided for top-level operation handle}}
+ %arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_value) {
+ transform.yield
+ }
}
// -----
-// expected-error @below {{operation expects 1 extra value bindings, but 2 were provided to the interpreter}}
-transform.sequence failures(propagate) {
-^bb0(%arg0: !transform.any_op, %arg1: !transform.any_value):
+module attributes {transform.with_named_sequence} {
+ // expected-error @below {{operation expects 1 extra value bindings, but 2 were provided to the interpreter}}
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op, %arg1: !transform.any_value) {
+ transform.yield
+ }
}
diff --git a/mlir/test/Dialect/Transform/test-interpreter-debug.mlir b/mlir/test/Dialect/Transform/test-interpreter-debug.mlir
index c7dad582dd432c..99301ea23c6f8d 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-debug.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-debug.mlir
@@ -1,19 +1,21 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{debug-payload-root-tag=payload debug-transform-root-tag=transform})" \
-// RUN: --allow-unregistered-dialect --split-input-file --verify-diagnostics
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(transform-interpreter{\
+// RUN: debug-payload-root-tag=payload \
+// RUN: entry-point=transform})" \
+// RUN: --allow-unregistered-dialect --split-input-file --verify-diagnostics
// expected-error @below {{could not find the operation with transform.target_tag="payload" attribute}}
-module {
- transform.sequence failures(suppress) {
- ^bb0(%arg0: !transform.any_op):
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @transform(%arg0: !transform.any_op) {
+ transform.yield
}
}
// -----
-// expected-error @below {{could not find the operation with transform.target_tag="transform" attribute}}
-module {
- transform.sequence failures(suppress) {
- ^bb0(%arg0: !transform.any_op):
+// expected-error @below {{could not find a nested named sequence with name: transform}}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @not_transform(%arg0: !transform.any_op) {
+ transform.yield
}
module attributes {transform.target_tag="payload"} {}
@@ -21,42 +23,16 @@ module {
// -----
-// expected-error @below {{more than one operation with transform.target_tag="transform" attribute}}
-module {
- // expected-note @below {{first operation}}
- transform.sequence failures(propagate) attributes {transform.target_tag="transform"} {
- ^bb0(%arg0: !transform.any_op):
- }
-
- // expected-note @below {{other operation}}
- transform.sequence failures(propagate) attributes {transform.target_tag="transform"} {
- ^bb0(%arg0: !transform.any_op):
- }
-
- module attributes {transform.target_tag="payload"} {}
-}
-
-// -----
-
-module {
- // expected-error @below {{expected the transform entry point to be a top-level transform op}}
- func.func private @foo() attributes {transf...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/81279
More information about the Mlir-commits mailing list