[Mlir-commits] [mlir] [mlir][Transform] Relax the applicability of transform.foreach_match … (PR #70209)
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Oct 30 04:23:35 PDT 2023
https://github.com/nicolasvasilache updated https://github.com/llvm/llvm-project/pull/70209
>From ea565d899a549a4ee90d2a70e1acb8a664375186 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolas.vasilache at gmail.com>
Date: Wed, 25 Oct 2023 13:43:25 +0000
Subject: [PATCH] [mlir][Transform] Relax the applicability of
transform.foreach_match to also take into account the op itself
---
.../mlir/Dialect/Transform/IR/TransformOps.td | 24 +++++++++++++++----
.../lib/Dialect/Transform/IR/TransformOps.cpp | 12 ++++++----
.../Dialect/Linalg/match-ops-interpreter.mlir | 3 ++-
3 files changed, 28 insertions(+), 11 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index b14c89eadb097d9..2fd0e80db96feba 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -481,8 +481,16 @@ def ForeachMatchOp : TransformDialectOp<"foreach_match", [
This operation consumes the operand and produces a new handle associated
with the same payload. This is necessary to trigger invalidation of handles
to any of the payload operations nested in the payload operations associated
- with the operand, as those are likely to be modified by actions. Note that
- the root payload operation associated with the operand are not matched.
+ with the operand, as those are likely to be modified by actions.
+
+ By default, the root payload operations associated with the operand are not
+ matched themselves; only the operations nested under them are. This
+ conservatively accounts for actions that may invalidate the root payload
+ operation. Setting the optional `restrict_root` attribute asserts that none
+ of the applied actions invalidate the root payload operation; in that case,
+ the root payload operation is matched as well. This is useful because
+ matching the root payload operation is a common idiom, e.g., when matching a
+ `func.func` directly in addition to the operations nested under it.
The operation succeeds if none of the matchers produced a definite failure
during application and if all of the applied actions produced success. Note
@@ -495,13 +503,19 @@ def ForeachMatchOp : TransformDialectOp<"foreach_match", [
}];
let arguments = (ins TransformHandleTypeInterface:$root,
+ UnitAttr:$restrict_root,
SymbolRefArrayAttr:$matchers,
SymbolRefArrayAttr:$actions);
let results = (outs TransformHandleTypeInterface:$updated);
- let assemblyFormat =
- "`in` $root custom<ForeachMatchSymbols>($matchers, $actions) "
- "attr-dict `:` functional-type($root, $updated)";
+ let assemblyFormat = [{
+ (`restrict_root` $restrict_root^)?
+ `in`
+ $root
+ custom<ForeachMatchSymbols>($matchers, $actions)
+ attr-dict
+ `:` functional-type($root, $updated)
+ }];
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 8db77b6059dd2e3..514a75b5d590469 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -850,8 +850,9 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
for (Operation *root : state.getPayloadOps(getRoot())) {
WalkResult walkResult = root->walk([&](Operation *op) {
- // Skip over the root op itself so we don't invalidate it.
- if (op == root)
+ // If getRestrictRoot is not present, skip over the root op itself so we
+ // don't invalidate it.
+ if (!getRestrictRoot() && op == root)
return WalkResult::advance();
DEBUG_MATCHER({
@@ -1556,10 +1557,10 @@ DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
::std::optional<::mlir::Operation *> maybeCurrent,
transform::TransformResults &results, transform::TransformState &state) {
if (!maybeCurrent.has_value()) {
- DBGS_MATCHER() << "MatchOperationEmptyOp success\n";
+ DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp success\n"; });
return DiagnosedSilenceableFailure::success();
}
- DBGS_MATCHER() << "MatchOperationEmptyOp failure\n";
+ DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp failure\n"; });
return emitSilenceableError() << "operation is not empty";
}
@@ -1961,7 +1962,8 @@ void transform::NamedSequenceOp::build(OpBuilder &builder,
state.addAttribute(SymbolTable::getSymbolAttrName(),
builder.getStringAttr(symName));
state.addAttribute(getFunctionTypeAttrName(state.name),
- TypeAttr::get(FunctionType::get(builder.getContext(), rootType, resultTypes)));
+ TypeAttr::get(FunctionType::get(builder.getContext(),
+ rootType, resultTypes)));
state.attributes.append(attrs.begin(), attrs.end());
state.addRegion();
diff --git a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
index 9489aadac843d7b..c88945c8a5c60fd 100644
--- a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
+++ b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
@@ -100,12 +100,13 @@ module attributes { transform.with_named_sequence } {
}
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.consumed}) {
- transform.foreach_match in %arg0
+ transform.foreach_match restrict_root in %arg0
@match_structured_suppress -> @do_nothing
: (!transform.any_op) -> !transform.any_op
transform.yield
}
+ // expected-remark @below {{other}}
func.func @payload() attributes { transform.target_tag = "start_here" } {
// expected-remark @below {{other}}
%D = arith.constant dense<1.0> : tensor<2x4xf32>
More information about the Mlir-commits
mailing list