[Mlir-commits] [mlir] [mlir][transform] Check for invalidated iterators on payload IR mappings (PR #66369)

Matthias Springer llvmlistbot at llvm.org
Thu Sep 14 06:10:55 PDT 2023


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/66369:

>From 52551391c46e6da273766db34fe5cd4e96d28f1d Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Thu, 14 Sep 2023 15:09:53 +0200
Subject: [PATCH 1/2] [mlir][transform] Fix crash when op is erased during
 transform.foreach

Fixes a crash when an op that is mapped to a handle that a `transform.foreach` iterates over is erased (through the `TrackingRewriter`). Erasing an op removes it from all mappings and invalidates iterators. This is already taken care of when an op is iterating over payload ops in its `apply` method, but not when another transform op is erasing a tracked payload op.
---
 .../Transform/IR/TransformInterfaces.h        |  5 +++++
 .../Transform/IR/TransformInterfaces.cpp      | 14 ++++++++++++
 .../lib/Dialect/Transform/IR/TransformOps.cpp | 18 +++++++++++++--
 .../Dialect/Transform/test-interpreter.mlir   | 22 +++++++++++++++++++
 .../TestTransformDialectExtension.cpp         |  2 +-
 5 files changed, 58 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 114d79555dcef50..efd8d573936c332 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -1156,6 +1156,11 @@ bool isHandleConsumed(Value handle, transform::TransformOpInterface transform);
 void modifiesPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
 void onlyReadsPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
 
+/// Checks whether the transform op modifies the payload.
+bool doesModifyPayload(transform::TransformOpInterface transform);
+/// Checks whether the transform op reads the payload.
+bool doesReadPayload(transform::TransformOpInterface transform);
+
 /// Populates `consumedArguments` with positions of `block` arguments that are
 /// consumed by the operations in the `block`.
 void getConsumedBlockArguments(
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index ed987ac4b51646c..00450a1ff8f36cf 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -1904,6 +1904,20 @@ void transform::onlyReadsPayload(
   effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
 }
 
+bool transform::doesModifyPayload(transform::TransformOpInterface transform) {
+  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
+  SmallVector<MemoryEffects::EffectInstance> effects;
+  iface.getEffects(effects);
+  return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects);
+}
+
+bool transform::doesReadPayload(transform::TransformOpInterface transform) {
+  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
+  SmallVector<MemoryEffects::EffectInstance> effects;
+  iface.getEffects(effects);
+  return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(effects);
+}
+
 void transform::getConsumedBlockArguments(
     Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
   SmallVector<MemoryEffects::EffectInstance> effects;
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 7bbbbba4134b184..a56adcfd7fd8472 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1121,8 +1121,11 @@ transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
                             transform::TransformResults &results,
                             transform::TransformState &state) {
   SmallVector<SmallVector<Operation *>> resultOps(getNumResults(), {});
-
-  for (Operation *op : state.getPayloadOps(getTarget())) {
+  // Store payload ops in a vector because ops may be removed from the mapping
+  // by the TrackingRewriter while the iteration is in progress.
+  SmallVector<Operation *> targets =
+      llvm::to_vector(state.getPayloadOps(getTarget()));
+  for (Operation *op : targets) {
     auto scope = state.make_region_scope(getBody());
     if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
       return DiagnosedSilenceableFailure::definiteFailure();
@@ -1152,6 +1155,7 @@ void transform::ForeachOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   BlockArgument iterVar = getIterationVariable();
   if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
+
         return isHandleConsumed(iterVar, cast<TransformOpInterface>(&op));
       })) {
     consumesHandle(getTarget(), effects);
@@ -1159,6 +1163,16 @@ void transform::ForeachOp::getEffects(
     onlyReadsHandle(getTarget(), effects);
   }
 
+  if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
+        return doesModifyPayload(cast<TransformOpInterface>(&op));
+      })) {
+    modifiesPayload(effects);
+  } else if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
+               return doesReadPayload(cast<TransformOpInterface>(&op));
+             })) {
+    onlyReadsPayload(effects);
+  }
+
   for (Value result : getResults())
     producesHandle(result, effects);
 }
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index db97d0a0887576f..68e3a4851539690 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -691,6 +691,28 @@ transform.with_pdl_patterns {
 
 // -----
 
+// CHECK-LABEL: func @consume_in_foreach()
+//  CHECK-NEXT:   return
+func.func @consume_in_foreach() {
+  %0 = arith.constant 0 : index
+  %1 = arith.constant 1 : index
+  %2 = arith.constant 2 : index
+  %3 = arith.constant 3 : index
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %f = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.foreach %f : !transform.any_op {
+  ^bb2(%arg2: !transform.any_op):
+    // expected-remark @below {{erasing}}
+    transform.test_emit_remark_and_erase_operand %arg2, "erasing" : !transform.any_op
+  }
+}
+
+// -----
+
 func.func @bar() {
   scf.execute_region {
     // expected-remark @below {{transform applied}}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index afd5011f17c6d2b..21f9ff5999a5ed5 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -390,7 +390,7 @@ DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   emitRemark() << getRemark();
   for (Operation *op : state.getPayloadOps(getTarget()))
-    op->erase();
+    rewriter.eraseOp(op);
 
   if (getFailAfterErase())
     return emitSilenceableError() << "silenceable error";

>From 608c6d5eb252d871cb03115a8b0fd07bbdf735e8 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Thu, 14 Sep 2023 15:10:06 +0200
Subject: [PATCH 2/2] [mlir][transform] Check for invalidated iterators on
 payload IR mappings

Add extra error checking (in debug mode) to detect cases where an iterator on "direct" payload IR mappings is invalidated (due to elements being removed). Such errors are hard to debug: they are often non-deterministic; sometimes the program crashes, sometimes it produces wrong results. Even when it crashes, the stack trace often points to completely unrelated code locations.

Store a timestamp with each "direct" mapping. The timestamp is increased whenever an operation is performed that invalidates an iterator on that mapping. A debug iterator is added that checks the timestamp before dereferencing or incrementing.
---
 .../Transform/IR/TransformInterfaces.h        | 31 +++++++++++++++++--
 .../Transform/IR/TransformInterfaces.cpp      | 18 +++++++++++
 .../TestTransformDialectExtension.cpp         |  5 +--
 3 files changed, 50 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index efd8d573936c332..88c1250dd241571 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -170,6 +170,12 @@ class TransformState {
   /// should be emitted when the value is used.
   using InvalidatedHandleMap = DenseMap<Value, std::function<void(Location)>>;
 
+#ifndef NDEBUG
+  /// Debug only: A timestamp is associated with each transform IR value, so
+  /// that invalid iterator usage can be detected more reliably.
+  using TransformIRTimestampMapping = DenseMap<Value, int64_t>;
+#endif // NDEBUG
+
   /// The bidirectional mappings between transform IR values and payload IR
   /// operations, and the mapping between transform IR values and parameters.
   struct Mappings {
@@ -178,6 +184,11 @@ class TransformState {
     ParamMapping params;
     ValueMapping values;
     ValueMapping reverseValues;
+
+#ifndef NDEBUG
+    TransformIRTimestampMapping timestamps;
+    void incrementTimestamp(Value value) { ++timestamps[value]; }
+#endif // NDEBUG
   };
 
   friend LogicalResult applyTransforms(Operation *, TransformOpInterface,
@@ -207,10 +218,26 @@ class TransformState {
   /// not enumerated. This function is helpful for transformations that apply to
   /// a particular handle.
   auto getPayloadOps(Value value) const {
+    ArrayRef<Operation *> view = getPayloadOpsView(value);
+
+#ifndef NDEBUG
+    // Memorize the current timestamp and make sure that it has not changed
+    // when incrementing or dereferencing the iterator returned by this
+    // function. The timestamp is incremented when the "direct" mapping is
+    // resized; this would invalidate the iterator returned by this function.
+    int64_t currentTimestamp = getMapping(value).timestamps.lookup(value);
+#endif // NDEBUG
+
     // When ops are replaced/erased, they are replaced with nullptr (until
     // the data structure is compacted). Do not enumerate these ops.
-    return llvm::make_filter_range(getPayloadOpsView(value),
-                                   [](Operation *op) { return op != nullptr; });
+    return llvm::make_filter_range(view, [=](Operation *op) {
+#ifndef NDEBUG
+      bool sameTimestamp =
+          currentTimestamp == this->getMapping(value).timestamps.lookup(value);
+      assert(sameTimestamp && "iterator was invalidated during iteration");
+#endif // NDEBUG
+      return op != nullptr;
+    });
   }
 
   /// Returns the list of parameters that the given transform IR value
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 00450a1ff8f36cf..a091047c440de35 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -310,6 +310,11 @@ void transform::TransformState::forgetMapping(Value opHandle,
   for (Operation *op : mappings.direct[opHandle])
     dropMappingEntry(mappings.reverse, op, opHandle);
   mappings.direct.erase(opHandle);
+#ifndef NDEBUG
+  // Payload IR is removed from the mapping. This invalidates the respective
+  // iterators.
+  mappings.incrementTimestamp(opHandle);
+#endif // NDEBUG
 
   for (Value opResult : origOpFlatResults) {
     SmallVector<Value> resultHandles;
@@ -336,6 +341,12 @@ void transform::TransformState::forgetValueMapping(
       Mappings &localMappings = getMapping(opHandle);
       dropMappingEntry(localMappings.direct, opHandle, payloadOp);
       dropMappingEntry(localMappings.reverse, payloadOp, opHandle);
+
+#ifndef NDEBUG
+      // Payload IR is removed from the mapping. This invalidates the respective
+      // iterators.
+      localMappings.incrementTimestamp(opHandle);
+#endif // NDEBUG
     }
   }
 }
@@ -774,6 +785,13 @@ checkRepeatedConsumptionInOperand(ArrayRef<T> payload,
 void transform::TransformState::compactOpHandles() {
   for (Value handle : opHandlesToCompact) {
     Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
+#ifndef NDEBUG
+    if (llvm::find(mappings.direct[handle], nullptr) !=
+        mappings.direct[handle].end())
+      // Payload IR is removed from the mapping. This invalidates the respective
+      // iterators.
+      mappings.incrementTimestamp(handle);
+#endif // NDEBUG
     llvm::erase_value(mappings.direct[handle], nullptr);
   }
   opHandlesToCompact.clear();
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 21f9ff5999a5ed5..3e5f1baac684d42 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -360,8 +360,9 @@ DiagnosedSilenceableFailure mlir::test::TestRemoveTestExtensionOp::apply(
 DiagnosedSilenceableFailure mlir::test::TestReversePayloadOpsOp::apply(
     transform::TransformRewriter &rewriter,
     transform::TransformResults &results, transform::TransformState &state) {
-  auto payloadOps = state.getPayloadOps(getTarget());
-  auto reversedOps = llvm::to_vector(llvm::reverse(payloadOps));
+  auto it = state.getPayloadOps(getTarget());
+  SmallVector<Operation *> reversedOps(it.begin(), it.end());
+  reversedOps = llvm::to_vector(llvm::reverse(reversedOps));
   results.set(llvm::cast<OpResult>(getResult()), reversedOps);
   return DiagnosedSilenceableFailure::success();
 }



More information about the Mlir-commits mailing list