[Mlir-commits] [mlir] [mlir][transform] Fix crash when op is erased during transform.foreach (PR #66357)
Matthias Springer
llvmlistbot at llvm.org
Thu Sep 14 05:59:00 PDT 2023
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/66357:
>From 95ecc96561b54e91317566bc1e1dfa0ea7c39384 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Thu, 14 Sep 2023 14:58:23 +0200
Subject: [PATCH] [mlir][transform] Fix crash when op is erased during
transform.foreach
Fixes a crash that occurs when an op mapped to a handle that a `transform.foreach` iterates over is erased (through the `TrackingRewriter`). Erasing an op removes it from all mappings and invalidates iterators. This is already taken care of when an op is iterating over payload ops in its `apply` method, but not when another transform op is erasing a tracked payload op.
---
.../Transform/IR/TransformInterfaces.h | 5 +++++
.../Transform/IR/TransformInterfaces.cpp | 14 ++++++++++++
.../lib/Dialect/Transform/IR/TransformOps.cpp | 18 +++++++++++++--
.../Dialect/Transform/test-interpreter.mlir | 22 +++++++++++++++++++
.../TestTransformDialectExtension.cpp | 2 +-
5 files changed, 58 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 114d79555dcef50..efd8d573936c332 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -1156,6 +1156,11 @@ bool isHandleConsumed(Value handle, transform::TransformOpInterface transform);
void modifiesPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
void onlyReadsPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
+/// Checks whether the transform op modifies the payload.
+bool doesModifyPayload(transform::TransformOpInterface transform);
+/// Checks whether the transform op reads the payload.
+bool doesReadPayload(transform::TransformOpInterface transform);
+
/// Populates `consumedArguments` with positions of `block` arguments that are
/// consumed by the operations in the `block`.
void getConsumedBlockArguments(
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index ed987ac4b51646c..00450a1ff8f36cf 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -1904,6 +1904,20 @@ void transform::onlyReadsPayload(
effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
}
+bool transform::doesModifyPayload(transform::TransformOpInterface transform) {
+ auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
+ SmallVector<MemoryEffects::EffectInstance> effects;
+ iface.getEffects(effects);
+ return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects);
+}
+
+bool transform::doesReadPayload(transform::TransformOpInterface transform) {
+ auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
+ SmallVector<MemoryEffects::EffectInstance> effects;
+ iface.getEffects(effects);
+ return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(effects);
+}
+
void transform::getConsumedBlockArguments(
Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
SmallVector<MemoryEffects::EffectInstance> effects;
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 7bbbbba4134b184..a56adcfd7fd8472 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1121,8 +1121,11 @@ transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
SmallVector<SmallVector<Operation *>> resultOps(getNumResults(), {});
-
- for (Operation *op : state.getPayloadOps(getTarget())) {
+ // Store payload ops in a vector because ops may be removed from the mapping
+ // by the TrackingRewriter while the iteration is in progress.
+ SmallVector<Operation *> targets =
+ llvm::to_vector(state.getPayloadOps(getTarget()));
+ for (Operation *op : targets) {
auto scope = state.make_region_scope(getBody());
if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
return DiagnosedSilenceableFailure::definiteFailure();
@@ -1152,6 +1155,7 @@ void transform::ForeachOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
BlockArgument iterVar = getIterationVariable();
if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
+
return isHandleConsumed(iterVar, cast<TransformOpInterface>(&op));
})) {
consumesHandle(getTarget(), effects);
@@ -1159,6 +1163,16 @@ void transform::ForeachOp::getEffects(
onlyReadsHandle(getTarget(), effects);
}
+ if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
+ return doesModifyPayload(cast<TransformOpInterface>(&op));
+ })) {
+ modifiesPayload(effects);
+ } else if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
+ return doesReadPayload(cast<TransformOpInterface>(&op));
+ })) {
+ onlyReadsPayload(effects);
+ }
+
for (Value result : getResults())
producesHandle(result, effects);
}
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index db97d0a0887576f..68e3a4851539690 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -691,6 +691,28 @@ transform.with_pdl_patterns {
// -----
+// CHECK-LABEL: func @consume_in_foreach()
+// CHECK-NEXT: return
+func.func @consume_in_foreach() {
+ %0 = arith.constant 0 : index
+ %1 = arith.constant 1 : index
+ %2 = arith.constant 2 : index
+ %3 = arith.constant 3 : index
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %f = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.foreach %f : !transform.any_op {
+ ^bb2(%arg2: !transform.any_op):
+ // expected-remark @below {{erasing}}
+ transform.test_emit_remark_and_erase_operand %arg2, "erasing" : !transform.any_op
+ }
+}
+
+// -----
+
func.func @bar() {
scf.execute_region {
// expected-remark @below {{transform applied}}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index afd5011f17c6d2b..21f9ff5999a5ed5 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -390,7 +390,7 @@ DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply(
transform::TransformResults &results, transform::TransformState &state) {
emitRemark() << getRemark();
for (Operation *op : state.getPayloadOps(getTarget()))
- op->erase();
+ rewriter.eraseOp(op);
if (getFailAfterErase())
return emitSilenceableError() << "silenceable error";
More information about the Mlir-commits
mailing list