[Mlir-commits] [mlir] [mlir][transform] Fix crash when consuming an op in a named sequence (PR #67437)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Sep 26 07:42:12 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
<details>
<summary>Changes</summary>
Fix a crash when consuming an op in a named sequence, when the same op is also mapped in the caller's mapping. Ops must be removed from *all* mappings during the "expensive checks". Otherwise, we may have dangling pointers in the mappings data structures, which interfere with the expensive checks.
---
Full diff: https://github.com/llvm/llvm-project/pull/67437.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h (+7-2)
- (modified) mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp (+9-6)
- (modified) mlir/test/Dialect/Transform/expensive-checks.mlir (+19)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index b45861e6190c18a..b4523144b80c660 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -617,10 +617,15 @@ class TransformState {
/// Forgets the payload IR ops associated with the given transform IR value,
/// as well as any association between value handles and the results of said
/// payload IR op.
- void forgetMapping(Value opHandle, ValueRange origOpFlatResults);
+ ///
+ /// If `allowOutOfScope` is set to "false", asserts that the handle is in
+ /// scope, based on the current stack of frames.
+ void forgetMapping(Value opHandle, ValueRange origOpFlatResults,
+ bool allowOutOfScope = false);
void forgetValueMapping(Value valueHandle,
- ArrayRef<Operation *> payloadOperations);
+ ArrayRef<Operation *> payloadOperations,
+ bool allowOutOfScope = false);
/// Replaces the given payload op with another op. If the replacement op is
/// null, removes the association of the payload op with its handle. Returns
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 483b0e7f7a4f998..4556119f9fae927 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -305,8 +305,9 @@ void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) {
}
void transform::TransformState::forgetMapping(Value opHandle,
- ValueRange origOpFlatResults) {
- Mappings &mappings = getMapping(opHandle);
+ ValueRange origOpFlatResults,
+ bool allowOutOfScope) {
+ Mappings &mappings = getMapping(opHandle, allowOutOfScope);
for (Operation *op : mappings.direct[opHandle])
dropMappingEntry(mappings.reverse, op, opHandle);
mappings.direct.erase(opHandle);
@@ -333,8 +334,9 @@ void transform::TransformState::forgetMapping(Value opHandle,
}
void transform::TransformState::forgetValueMapping(
- Value valueHandle, ArrayRef<Operation *> payloadOperations) {
- Mappings &mappings = getMapping(valueHandle);
+ Value valueHandle, ArrayRef<Operation *> payloadOperations,
+ bool allowOutOfScope) {
+ Mappings &mappings = getMapping(valueHandle, allowOutOfScope);
for (Value payloadValue : mappings.reverseValues[valueHandle])
dropMappingEntry(mappings.reverseValues, payloadValue, valueHandle);
mappings.values.erase(valueHandle);
@@ -1021,9 +1023,10 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
// pre-generated error messages, so we do not need the association to
// still be there when the invalidated handle is accessed.
SmallVector<Value> handles;
- (void)getHandlesForPayloadOp(op, handles);
+ (void)getHandlesForPayloadOp(op, handles, /*includeOutOfScope=*/true);
for (Value handle : handles)
- forgetMapping(handle, /*origOpFlatResults=*/ValueRange());
+ forgetMapping(handle, /*origOpFlatResults=*/ValueRange(),
+ /*allowOutOfScope=*/true);
cachedNames.erase(op);
}
diff --git a/mlir/test/Dialect/Transform/expensive-checks.mlir b/mlir/test/Dialect/Transform/expensive-checks.mlir
index fef857bcef029df..ee9a5af8055247e 100644
--- a/mlir/test/Dialect/Transform/expensive-checks.mlir
+++ b/mlir/test/Dialect/Transform/expensive-checks.mlir
@@ -410,3 +410,22 @@ transform.sequence failures(propagate) {
transform.yield
}
}
+
+// -----
+
+module @named_inclusion_and_consumption attributes { transform.with_named_sequence } {
+
+ transform.named_sequence @foo(%arg0: !transform.any_op {transform.consumed}) -> () {
+ // Consuming this handle removes the mapping from the current stack frame
+ // mapping and from the caller's stack frame mapping. (If this were not
+ // the case, the "expensive checks" caching mechanism for op names
+ // would throw an error saying that an op is mapped but not in the cache.)
+ transform.test_consume_operand %arg0 : !transform.any_op
+ transform.yield
+ }
+
+ transform.sequence failures(propagate) {
+ ^bb0(%arg0: !transform.any_op):
+ include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+ }
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/67437
More information about the Mlir-commits
mailing list