[Mlir-commits] [mlir] [mlir][transform] Fix crash when consuming an op in a named sequence (PR #67437)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 26 07:42:12 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

<details>
<summary>Changes</summary>

Fix a crash that occurs when an op is consumed in a named sequence while the same op is also mapped in the caller's mapping. During the "expensive checks", ops must be removed from *all* mappings. Otherwise, dangling pointers may remain in the mapping data structures and interfere with the expensive checks.

---
Full diff: https://github.com/llvm/llvm-project/pull/67437.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h (+7-2) 
- (modified) mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp (+9-6) 
- (modified) mlir/test/Dialect/Transform/expensive-checks.mlir (+19) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index b45861e6190c18a..b4523144b80c660 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -617,10 +617,15 @@ class TransformState {
   /// Forgets the payload IR ops associated with the given transform IR value,
   /// as well as any association between value handles and the results of said
   /// payload IR op.
-  void forgetMapping(Value opHandle, ValueRange origOpFlatResults);
+  ///
+  /// If `allowOutOfScope` is set to "false", asserts that the handle is in
+  /// scope, based on the current stack of frames.
+  void forgetMapping(Value opHandle, ValueRange origOpFlatResults,
+                     bool allowOutOfScope = false);
 
   void forgetValueMapping(Value valueHandle,
-                          ArrayRef<Operation *> payloadOperations);
+                          ArrayRef<Operation *> payloadOperations,
+                          bool allowOutOfScope = false);
 
   /// Replaces the given payload op with another op. If the replacement op is
   /// null, removes the association of the payload op with its handle. Returns
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 483b0e7f7a4f998..4556119f9fae927 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -305,8 +305,9 @@ void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) {
 }
 
 void transform::TransformState::forgetMapping(Value opHandle,
-                                              ValueRange origOpFlatResults) {
-  Mappings &mappings = getMapping(opHandle);
+                                              ValueRange origOpFlatResults,
+                                              bool allowOutOfScope) {
+  Mappings &mappings = getMapping(opHandle, allowOutOfScope);
   for (Operation *op : mappings.direct[opHandle])
     dropMappingEntry(mappings.reverse, op, opHandle);
   mappings.direct.erase(opHandle);
@@ -333,8 +334,9 @@ void transform::TransformState::forgetMapping(Value opHandle,
 }
 
 void transform::TransformState::forgetValueMapping(
-    Value valueHandle, ArrayRef<Operation *> payloadOperations) {
-  Mappings &mappings = getMapping(valueHandle);
+    Value valueHandle, ArrayRef<Operation *> payloadOperations,
+    bool allowOutOfScope) {
+  Mappings &mappings = getMapping(valueHandle, allowOutOfScope);
   for (Value payloadValue : mappings.reverseValues[valueHandle])
     dropMappingEntry(mappings.reverseValues, payloadValue, valueHandle);
   mappings.values.erase(valueHandle);
@@ -1021,9 +1023,10 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
       // pre-generated error messages, so we do not need the association to
       // still be there when the invalidated handle is accessed.
       SmallVector<Value> handles;
-      (void)getHandlesForPayloadOp(op, handles);
+      (void)getHandlesForPayloadOp(op, handles, /*includeOutOfScope=*/true);
       for (Value handle : handles)
-        forgetMapping(handle, /*origOpFlatResults=*/ValueRange());
+        forgetMapping(handle, /*origOpFlatResults=*/ValueRange(),
+                      /*allowOutOfScope=*/true);
       cachedNames.erase(op);
     }
 
diff --git a/mlir/test/Dialect/Transform/expensive-checks.mlir b/mlir/test/Dialect/Transform/expensive-checks.mlir
index fef857bcef029df..ee9a5af8055247e 100644
--- a/mlir/test/Dialect/Transform/expensive-checks.mlir
+++ b/mlir/test/Dialect/Transform/expensive-checks.mlir
@@ -410,3 +410,22 @@ transform.sequence failures(propagate) {
     transform.yield
   }
 }
+
+// -----
+
+module @named_inclusion_and_consumption attributes { transform.with_named_sequence } {
+
+  transform.named_sequence @foo(%arg0: !transform.any_op {transform.consumed}) -> () {
+    // Consuming this handle removes the mapping from the current stack frame
+    // mapping and from the caller's stack frame mapping. (If this were not
+    // the case, the "expensive checks" caching mechanism for op names
+    // would throw an error saying that an op is mapped but not in the cache.)
+    transform.test_consume_operand %arg0 : !transform.any_op
+    transform.yield
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+  }
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/67437


More information about the Mlir-commits mailing list