[Mlir-commits] [mlir] 4f63252 - [mlir][transform] Fix crash when op is erased during transform.foreach (#66357)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 14 05:59:40 PDT 2023


Author: Matthias Springer
Date: 2023-09-14T14:59:36+02:00
New Revision: 4f63252d5d0acfc32633a36e18fff469a606e6ce

URL: https://github.com/llvm/llvm-project/commit/4f63252d5d0acfc32633a36e18fff469a606e6ce
DIFF: https://github.com/llvm/llvm-project/commit/4f63252d5d0acfc32633a36e18fff469a606e6ce.diff

LOG: [mlir][transform] Fix crash when op is erased during transform.foreach (#66357)

Fixes a crash that occurred when an op mapped to a handle that a
`transform.foreach` iterates over was erased (through the
`TrackingRewriter`). Erasing an op removes it from all mappings and
invalidates iterators. This is already taken care of when an op is
iterating over payload ops in its own `apply` method, but not when
another transform op erases a tracked payload op.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
    mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/test/Dialect/Transform/test-interpreter.mlir
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 114d79555dcef50..efd8d573936c332 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -1156,6 +1156,11 @@ bool isHandleConsumed(Value handle, transform::TransformOpInterface transform);
 void modifiesPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
 void onlyReadsPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
 
+/// Checks whether the transform op modifies the payload.
+bool doesModifyPayload(transform::TransformOpInterface transform);
+/// Checks whether the transform op reads the payload.
+bool doesReadPayload(transform::TransformOpInterface transform);
+
 /// Populates `consumedArguments` with positions of `block` arguments that are
 /// consumed by the operations in the `block`.
 void getConsumedBlockArguments(

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index ed987ac4b51646c..00450a1ff8f36cf 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -1904,6 +1904,20 @@ void transform::onlyReadsPayload(
   effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
 }
 
+bool transform::doesModifyPayload(transform::TransformOpInterface transform) {
+  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
+  SmallVector<MemoryEffects::EffectInstance> effects;
+  iface.getEffects(effects);
+  return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects);
+}
+
+bool transform::doesReadPayload(transform::TransformOpInterface transform) {
+  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
+  SmallVector<MemoryEffects::EffectInstance> effects;
+  iface.getEffects(effects);
+  return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(effects);
+}
+
 void transform::getConsumedBlockArguments(
     Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
   SmallVector<MemoryEffects::EffectInstance> effects;

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 7bbbbba4134b184..a56adcfd7fd8472 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1121,8 +1121,11 @@ transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
                             transform::TransformResults &results,
                             transform::TransformState &state) {
   SmallVector<SmallVector<Operation *>> resultOps(getNumResults(), {});
-
-  for (Operation *op : state.getPayloadOps(getTarget())) {
+  // Store payload ops in a vector because ops may be removed from the mapping
+  // by the TrackingRewriter while the iteration is in progress.
+  SmallVector<Operation *> targets =
+      llvm::to_vector(state.getPayloadOps(getTarget()));
+  for (Operation *op : targets) {
     auto scope = state.make_region_scope(getBody());
     if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
       return DiagnosedSilenceableFailure::definiteFailure();
@@ -1152,6 +1155,7 @@ void transform::ForeachOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   BlockArgument iterVar = getIterationVariable();
   if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
+
         return isHandleConsumed(iterVar, cast<TransformOpInterface>(&op));
       })) {
     consumesHandle(getTarget(), effects);
@@ -1159,6 +1163,16 @@ void transform::ForeachOp::getEffects(
     onlyReadsHandle(getTarget(), effects);
   }
 
+  if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
+        return doesModifyPayload(cast<TransformOpInterface>(&op));
+      })) {
+    modifiesPayload(effects);
+  } else if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
+               return doesReadPayload(cast<TransformOpInterface>(&op));
+             })) {
+    onlyReadsPayload(effects);
+  }
+
   for (Value result : getResults())
     producesHandle(result, effects);
 }

diff  --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index db97d0a0887576f..68e3a4851539690 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -691,6 +691,28 @@ transform.with_pdl_patterns {
 
 // -----
 
+// CHECK-LABEL: func @consume_in_foreach()
+//  CHECK-NEXT:   return
+func.func @consume_in_foreach() {
+  %0 = arith.constant 0 : index
+  %1 = arith.constant 1 : index
+  %2 = arith.constant 2 : index
+  %3 = arith.constant 3 : index
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %f = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.foreach %f : !transform.any_op {
+  ^bb2(%arg2: !transform.any_op):
+    // expected-remark @below {{erasing}}
+    transform.test_emit_remark_and_erase_operand %arg2, "erasing" : !transform.any_op
+  }
+}
+
+// -----
+
 func.func @bar() {
   scf.execute_region {
     // expected-remark @below {{transform applied}}

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index afd5011f17c6d2b..21f9ff5999a5ed5 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -390,7 +390,7 @@ DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   emitRemark() << getRemark();
   for (Operation *op : state.getPayloadOps(getTarget()))
-    op->erase();
+    rewriter.eraseOp(op);
 
   if (getFailAfterErase())
     return emitSilenceableError() << "silenceable error";


        


More information about the Mlir-commits mailing list