[Mlir-commits] [mlir] [mlir][transform] Fix crash when op is erased during transform.foreach (PR #66357)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 14 03:36:25 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir
            
<details>
<summary>Changes</summary>
Fixes a crash when an op that is mapped to a handle over which a `transform.foreach` iterates is erased (through the `TrackingRewriter`). Erasing an op removes it from all mappings, which invalidates any iterators over those mappings. This is already taken care of when a transform op iterates over payload ops in its own `apply` method, but not when a different transform op (such as one nested in the `transform.foreach` body) erases a tracked payload op.
--
Full diff: https://github.com/llvm/llvm-project/pull/66357.diff

5 Files Affected:

- (modified) mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h (+5) 
- (modified) mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp (+14) 
- (modified) mlir/lib/Dialect/Transform/IR/TransformOps.cpp (+16-2) 
- (modified) mlir/test/Dialect/Transform/test-interpreter.mlir (+22) 
- (modified) mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp (+1-1) 


<pre>
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 114d79555dcef50..efd8d573936c332 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -1156,6 +1156,11 @@ bool isHandleConsumed(Value handle, transform::TransformOpInterface transform);
 void modifiesPayload(SmallVectorImpl&lt;MemoryEffects::EffectInstance&gt; &amp;effects);
 void onlyReadsPayload(SmallVectorImpl&lt;MemoryEffects::EffectInstance&gt; &amp;effects);
 
+/// Checks whether the transform op modifies the payload.
+bool doesModifyPayload(transform::TransformOpInterface transform);
+/// Checks whether the transform op reads the payload.
+bool doesReadPayload(transform::TransformOpInterface transform);
+
 /// Populates `consumedArguments` with positions of `block` arguments that are
 /// consumed by the operations in the `block`.
 void getConsumedBlockArguments(
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index ed987ac4b51646c..00450a1ff8f36cf 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -1904,6 +1904,20 @@ void transform::onlyReadsPayload(
   effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
 }
 
+bool transform::doesModifyPayload(transform::TransformOpInterface transform) {
+  auto iface = cast&lt;MemoryEffectOpInterface&gt;(transform.getOperation());
+  SmallVector&lt;MemoryEffects::EffectInstance&gt; effects;
+  iface.getEffects(effects);
+  return ::hasEffect&lt;MemoryEffects::Write, PayloadIRResource&gt;(effects);
+}
+
+bool transform::doesReadPayload(transform::TransformOpInterface transform) {
+  auto iface = cast&lt;MemoryEffectOpInterface&gt;(transform.getOperation());
+  SmallVector&lt;MemoryEffects::EffectInstance&gt; effects;
+  iface.getEffects(effects);
+  return ::hasEffect&lt;MemoryEffects::Read, PayloadIRResource&gt;(effects);
+}
+
 void transform::getConsumedBlockArguments(
     Block &amp;block, llvm::SmallDenseSet&lt;unsigned int&gt; &amp;consumedArguments) {
   SmallVector&lt;MemoryEffects::EffectInstance&gt; effects;
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 7bbbbba4134b184..0c9a80699aed09b 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1121,8 +1121,11 @@ transform::ForeachOp::apply(transform::TransformRewriter &amp;rewriter,
                             transform::TransformResults &amp;results,
                             transform::TransformState &amp;state) {
   SmallVector&lt;SmallVector&lt;Operation *&gt;&gt; resultOps(getNumResults(), {});
-
-  for (Operation *op : state.getPayloadOps(getTarget())) {
+  // Store payload ops in a vector because ops may be removed from the mapping
+  // by the TrackingRewriter while the iteration is in progress.
+  auto it = state.getPayloadOps(getTarget());
+  SmallVector&lt;Operation *&gt; targets(it.begin(), it.end());
+  for (Operation *op : targets) {
     auto scope = state.make_region_scope(getBody());
     if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
       return DiagnosedSilenceableFailure::definiteFailure();
@@ -1152,6 +1155,7 @@ void transform::ForeachOp::getEffects(
     SmallVectorImpl&lt;MemoryEffects::EffectInstance&gt; &amp;effects) {
   BlockArgument iterVar = getIterationVariable();
   if (any_of(getBody().front().without_terminator(), [&amp;](Operation &amp;op) {
+
         return isHandleConsumed(iterVar, cast&lt;TransformOpInterface&gt;(&amp;op));
       })) {
     consumesHandle(getTarget(), effects);
@@ -1159,6 +1163,16 @@ void transform::ForeachOp::getEffects(
     onlyReadsHandle(getTarget(), effects);
   }
 
+  if (any_of(getBody().front().without_terminator(), [&amp;](Operation &amp;op) {
+        return doesModifyPayload(cast&lt;TransformOpInterface&gt;(&amp;op));
+      })) {
+    modifiesPayload(effects);
+  } else if (any_of(getBody().front().without_terminator(), [&amp;](Operation &amp;op) {
+               return doesReadPayload(cast&lt;TransformOpInterface&gt;(&amp;op));
+             })) {
+    onlyReadsPayload(effects);
+  }
+
   for (Value result : getResults())
     producesHandle(result, effects);
 }
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index db97d0a0887576f..68e3a4851539690 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -691,6 +691,28 @@ transform.with_pdl_patterns {
 
 // -----
 
+// CHECK-LABEL: func @consume_in_foreach()
+//  CHECK-NEXT:   return
+func.func @consume_in_foreach() {
+  %0 = arith.constant 0 : index
+  %1 = arith.constant 1 : index
+  %2 = arith.constant 2 : index
+  %3 = arith.constant 3 : index
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %f = transform.structured.match ops{[&quot;arith.constant&quot;]} in %arg1 : (!transform.any_op) -&gt; !transform.any_op
+  transform.foreach %f : !transform.any_op {
+  ^bb2(%arg2: !transform.any_op):
+    // expected-remark @below {{erasing}}
+    transform.test_emit_remark_and_erase_operand %arg2, &quot;erasing&quot; : !transform.any_op
+  }
+}
+
+// -----
+
 func.func @bar() {
   scf.execute_region {
     // expected-remark @below {{transform applied}}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index afd5011f17c6d2b..21f9ff5999a5ed5 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -390,7 +390,7 @@ DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply(
     transform::TransformResults &amp;results, transform::TransformState &amp;state) {
   emitRemark() &lt;&lt; getRemark();
   for (Operation *op : state.getPayloadOps(getTarget()))
-    op-&gt;erase();
+    rewriter.eraseOp(op);
 
   if (getFailAfterErase())
     return emitSilenceableError() &lt;&lt; &quot;silenceable error&quot;;
</pre>
</details>


https://github.com/llvm/llvm-project/pull/66357


More information about the Mlir-commits mailing list