[Mlir-commits] [mlir] [mlir][transform] Fix crash when op is erased during transform.foreach (PR #66357)

Matthias Springer llvmlistbot at llvm.org
Thu Sep 14 05:59:00 PDT 2023


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/66357:

>From 95ecc96561b54e91317566bc1e1dfa0ea7c39384 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Thu, 14 Sep 2023 14:58:23 +0200
Subject: [PATCH] [mlir][transform] Fix crash when op is erased during
 transform.foreach

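Fixes a crash when an op that is mapped to the handle that a `transform.foreach` iterates over is erased (through the `TrackingRewriter`). Erasing an op removes it from all mappings and invalidates iterators. This was already taken care of when an op iterates over payload ops in its own `apply` method, but not when a different transform op erases a tracked payload op while that iteration is in progress. In addition, `transform.foreach` now reports the payload read/write effects of the ops in its body, and the `test_emit_remark_and_erase_operand` test op now erases ops through the rewriter instead of calling `Operation::erase` directly.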
---
 .../Transform/IR/TransformInterfaces.h        |  5 +++++
 .../Transform/IR/TransformInterfaces.cpp      | 14 ++++++++++++
 .../lib/Dialect/Transform/IR/TransformOps.cpp | 17 ++++++++++++--
 .../Dialect/Transform/test-interpreter.mlir   | 22 +++++++++++++++++++
 .../TestTransformDialectExtension.cpp         |  2 +-
 5 files changed, 57 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 114d79555dcef50..efd8d573936c332 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -1156,6 +1156,11 @@ bool isHandleConsumed(Value handle, transform::TransformOpInterface transform);
 void modifiesPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
 void onlyReadsPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
 
+/// Checks whether the transform op modifies the payload.
+bool doesModifyPayload(transform::TransformOpInterface transform);
+/// Checks whether the transform op reads the payload.
+bool doesReadPayload(transform::TransformOpInterface transform);
+
 /// Populates `consumedArguments` with positions of `block` arguments that are
 /// consumed by the operations in the `block`.
 void getConsumedBlockArguments(
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index ed987ac4b51646c..00450a1ff8f36cf 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -1904,6 +1904,20 @@ void transform::onlyReadsPayload(
   effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
 }
 
+bool transform::doesModifyPayload(transform::TransformOpInterface transform) {
+  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
+  SmallVector<MemoryEffects::EffectInstance> effects;
+  iface.getEffects(effects);
+  return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects);
+}
+
+bool transform::doesReadPayload(transform::TransformOpInterface transform) {
+  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
+  SmallVector<MemoryEffects::EffectInstance> effects;
+  iface.getEffects(effects);
+  return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(effects);
+}
+
 void transform::getConsumedBlockArguments(
     Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
   SmallVector<MemoryEffects::EffectInstance> effects;
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 7bbbbba4134b184..a56adcfd7fd8472 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1121,8 +1121,11 @@ transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
                             transform::TransformResults &results,
                             transform::TransformState &state) {
   SmallVector<SmallVector<Operation *>> resultOps(getNumResults(), {});
-
-  for (Operation *op : state.getPayloadOps(getTarget())) {
+  // Store payload ops in a vector because ops may be removed from the mapping
+  // by the TrackingRewriter while the iteration is in progress.
+  SmallVector<Operation *> targets =
+      llvm::to_vector(state.getPayloadOps(getTarget()));
+  for (Operation *op : targets) {
     auto scope = state.make_region_scope(getBody());
     if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
       return DiagnosedSilenceableFailure::definiteFailure();
@@ -1159,6 +1162,16 @@ void transform::ForeachOp::getEffects(
     onlyReadsHandle(getTarget(), effects);
   }
 
+  if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
+        return doesModifyPayload(cast<TransformOpInterface>(&op));
+      })) {
+    modifiesPayload(effects);
+  } else if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
+               return doesReadPayload(cast<TransformOpInterface>(&op));
+             })) {
+    onlyReadsPayload(effects);
+  }
+
   for (Value result : getResults())
     producesHandle(result, effects);
 }
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index db97d0a0887576f..68e3a4851539690 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -691,6 +691,28 @@ transform.with_pdl_patterns {
 
 // -----
 
+// CHECK-LABEL: func @consume_in_foreach()
+//  CHECK-NEXT:   return
+func.func @consume_in_foreach() {
+  %0 = arith.constant 0 : index
+  %1 = arith.constant 1 : index
+  %2 = arith.constant 2 : index
+  %3 = arith.constant 3 : index
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %f = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.foreach %f : !transform.any_op {
+  ^bb2(%arg2: !transform.any_op):
+    // expected-remark @below {{erasing}}
+    transform.test_emit_remark_and_erase_operand %arg2, "erasing" : !transform.any_op
+  }
+}
+
+// -----
+
 func.func @bar() {
   scf.execute_region {
     // expected-remark @below {{transform applied}}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index afd5011f17c6d2b..21f9ff5999a5ed5 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -390,7 +390,7 @@ DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   emitRemark() << getRemark();
   for (Operation *op : state.getPayloadOps(getTarget()))
-    op->erase();
+    rewriter.eraseOp(op);
 
   if (getFailAfterErase())
     return emitSilenceableError() << "silenceable error";



More information about the Mlir-commits mailing list