[Mlir-commits] [mlir] [mlir] propagate silenceable failures in transform.foreach_match (PR #86956)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Thu Mar 28 07:37:52 PDT 2024
https://github.com/ftynse created https://github.com/llvm/llvm-project/pull/86956
The original implementation was eagerly reporting silenceable failures from actions as definite failures. Since silenceable failures are intended for cases when the IR has not been irreversibly modified, it's okay to propagate them as silenceable failures of the parent op.
Fixes #86834.
From 4ec09a046c1ca3a2791f444f88af6f994897d95a Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Thu, 28 Mar 2024 14:33:26 +0000
Subject: [PATCH] [mlir] propagate silenceable failures in transform.foreach_match
The original implementation was eagerly reporting silenceable failures
from actions as definite failures. Since silenceable failures are
intended for cases when the IR has not been irreversibly modified, it's
okay to propagate them as silenceable failures of the parent op.
Fixes #86834.
---
.../lib/Dialect/Transform/IR/TransformOps.cpp | 17 +++-
.../test/Dialect/Transform/foreach-match.mlir | 80 +++++++++++++++++++
2 files changed, 95 insertions(+), 2 deletions(-)
create mode 100644 mlir/test/Dialect/Transform/foreach-match.mlir
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 8d2ed8f6d73714..17bb661bddf3c0 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1020,6 +1020,8 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
}
+ DiagnosedSilenceableFailure overallDiag =
+ DiagnosedSilenceableFailure::success();
for (Operation *root : state.getPayloadOps(getRoot())) {
WalkResult walkResult = root->walk([&](Operation *op) {
// If getRestrictRoot is not present, skip over the root op itself so we
@@ -1058,8 +1060,19 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
action.getFunctionBody().front().without_terminator()) {
DiagnosedSilenceableFailure result =
state.applyTransform(cast<TransformOpInterface>(transform));
- if (failed(result.checkAndReport()))
+ if (result.isDefiniteFailure())
return WalkResult::interrupt();
+ if (result.isSilenceableFailure()) {
+ if (overallDiag.succeeded()) {
+ overallDiag = emitSilenceableError() << "actions failed";
+ }
+ overallDiag.attachNote(action->getLoc())
+ << "failed action: " << result.getMessage();
+ overallDiag.attachNote(op->getLoc())
+ << "when applied to this matching payload";
+ (void)result.silence();
+ continue;
+ }
}
break;
}
@@ -1075,7 +1088,7 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
// by actions, are invalidated.
results.set(llvm::cast<OpResult>(getUpdated()),
state.getPayloadOps(getRoot()));
- return DiagnosedSilenceableFailure::success();
+ return overallDiag;
}
void transform::ForeachMatchOp::getEffects(
diff --git a/mlir/test/Dialect/Transform/foreach-match.mlir b/mlir/test/Dialect/Transform/foreach-match.mlir
new file mode 100644
index 00000000000000..206625ae0746be
--- /dev/null
+++ b/mlir/test/Dialect/Transform/foreach-match.mlir
@@ -0,0 +1,80 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics
+
+// Silenceable diagnostics suppressed.
+module attributes { transform.with_named_sequence } {
+ func.func @test_loop_peeling_not_beneficial() {
+ %lb = arith.constant 0 : index
+ %ub = arith.constant 40 : index
+ %step = arith.constant 5 : index
+ scf.for %i = %lb to %ub step %step {
+ arith.addi %i, %i : index
+ }
+ return
+ }
+
+ transform.named_sequence @peel(%arg0: !transform.op<"scf.for"> {transform.consumed}) {
+ transform.loop.peel %arg0 : (!transform.op<"scf.for">) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+ transform.named_sequence @match_for(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
+ transform.match.operation_name %arg0 ["scf.for"] : !transform.any_op
+ transform.yield %arg0 : !transform.any_op
+ }
+ transform.named_sequence @__transform_main(%root: !transform.any_op) {
+ transform.sequence %root : !transform.any_op failures(suppress) {
+ ^bb0(%arg0: !transform.any_op):
+ transform.foreach_match in %arg0
+ @match_for -> @peel
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+ transform.yield
+ }
+}
+
+// -----
+
+// Silenceable diagnostics propagated.
+module attributes { transform.with_named_sequence } {
+ func.func @test_loop_peeling_not_beneficial() {
+ %lb = arith.constant 0 : index
+ %ub = arith.constant 40 : index
+ %step = arith.constant 5 : index
+ // expected-note @below {{when applied to this matching payload}}
+ scf.for %i = %lb to %ub step %step {
+ arith.addi %i, %i : index
+ }
+ return
+ }
+
+ // expected-note @below {{failed to peel the last iteration}}
+ transform.named_sequence @peel(%arg0: !transform.op<"scf.for"> {transform.consumed}) {
+ transform.loop.peel %arg0 : (!transform.op<"scf.for">) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+ transform.named_sequence @match_for(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
+ transform.match.operation_name %arg0 ["scf.for"] : !transform.any_op
+ transform.yield %arg0 : !transform.any_op
+ }
+ transform.named_sequence @main_suppress(%root: !transform.any_op) {
+ transform.sequence %root : !transform.any_op failures(suppress) {
+ ^bb0(%arg0: !transform.any_op):
+ transform.foreach_match in %arg0
+ @match_for -> @peel
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+ transform.yield
+ }
+ transform.named_sequence @__transform_main(%root: !transform.any_op) {
+ transform.sequence %root : !transform.any_op failures(propagate) {
+ ^bb0(%arg0: !transform.any_op):
+ // expected-error @below {{actions failed}}
+ transform.foreach_match in %arg0
+ @match_for -> @peel
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+ transform.yield
+ }
+}
More information about the Mlir-commits mailing list