[Mlir-commits] [mlir] 29e1fd9 - [mlir][transform] Fix TrackingListener in regions that are isolated from above
Matthias Springer
llvmlistbot at llvm.org
Mon Jun 26 09:08:38 PDT 2023
Author: Matthias Springer
Date: 2023-06-26T18:05:24+02:00
New Revision: 29e1fd9bdfcf798cb148141cc94a3ab6ab26cbdf
URL: https://github.com/llvm/llvm-project/commit/29e1fd9bdfcf798cb148141cc94a3ab6ab26cbdf
DIFF: https://github.com/llvm/llvm-project/commit/29e1fd9bdfcf798cb148141cc94a3ab6ab26cbdf.diff
LOG: [mlir][transform] Fix TrackingListener in regions that are isolated from above
When an operation is removed or replaced, the TrackingListener updates the internal transform state mapping between handles and payload IR. All handles must be updated, including those defined in regions that lie beyond the most recent region that is isolated from above.
This fixes a bug where a payload op was erased inside a named sequence. Not only must the handles defined inside the named sequence be updated, but also all other handles, such as those defined in the sequence that includes the named sequence.
Differential Revision: https://reviews.llvm.org/D153767
Added:
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/test/Dialect/Transform/test-pattern-application.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 20f9b2122e933..8411eb0dd8412 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -222,14 +222,19 @@ class TransformState {
/// Populates `handles` with all handles pointing to the given Payload IR op.
/// Returns success if such handles exist, failure otherwise.
+ /// If `includeOutOfScope` is set to "true", handles that are defined in
+ /// regions beyond the most recent isolated from above region are included.
LogicalResult getHandlesForPayloadOp(Operation *op,
- SmallVectorImpl<Value> &handles) const;
+ SmallVectorImpl<Value> &handles,
+ bool includeOutOfScope = false) const;
/// Populates `handles` with all handles pointing to the given payload IR
/// value. Returns success if such handles exist, failure otherwise.
- LogicalResult
- getHandlesForPayloadValue(Value payloadValue,
- SmallVectorImpl<Value> &handles) const;
+ /// If `includeOutOfScope` is set to "true", handles that are defined in
+ /// regions beyond the most recent isolated from above region are included.
+ LogicalResult getHandlesForPayloadValue(Value payloadValue,
+ SmallVectorImpl<Value> &handles,
+ bool includeOutOfScope = false) const;
/// Applies the transformation specified by the given transform op and updates
/// the state accordingly.
@@ -410,42 +415,53 @@ class TransformState {
const TransformOptions &options = TransformOptions());
/// Returns the mappings frame for the region in which the value is defined.
- const Mappings &getMapping(Value value) const {
- return const_cast<TransformState *>(this)->getMapping(value);
+ /// If `allowOutOfScope` is set to "false", asserts that the value is in
+ /// scope, based on the current stack of frames.
+ const Mappings &getMapping(Value value, bool allowOutOfScope = false) const {
+ return const_cast<TransformState *>(this)->getMapping(value,
+ allowOutOfScope);
}
- Mappings &getMapping(Value value) {
+ Mappings &getMapping(Value value, bool allowOutOfScope = false) {
Region *region = value.getParentRegion();
auto it = mappings.find(region);
assert(it != mappings.end() &&
"trying to find a mapping for a value from an unmapped region");
#ifndef NDEBUG
- for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
- if (r == region)
- break;
- if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
- llvm_unreachable(
- "trying to get mapping beyond region that is isolated from above");
+ if (!allowOutOfScope) {
+ for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
+ if (r == region)
+ break;
+ if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
+ llvm_unreachable("trying to get mapping beyond region that is "
+ "isolated from above");
+ }
}
#endif // NDEBUG
return it->second;
}
/// Returns the mappings frame for the region in which the operation resides.
- const Mappings &getMapping(Operation *operation) const {
- return const_cast<TransformState *>(this)->getMapping(operation);
+ /// If `allowOutOfScope` is set to "false", asserts that the operation is in
+ /// scope, based on the current stack of frames.
+ const Mappings &getMapping(Operation *operation,
+ bool allowOutOfScope = false) const {
+ return const_cast<TransformState *>(this)->getMapping(operation,
+ allowOutOfScope);
}
- Mappings &getMapping(Operation *operation) {
+ Mappings &getMapping(Operation *operation, bool allowOutOfScope = false) {
Region *region = operation->getParentRegion();
auto it = mappings.find(region);
assert(it != mappings.end() &&
"trying to find a mapping for an operation from an unmapped region");
#ifndef NDEBUG
- for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
- if (r == region)
- break;
- if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
- llvm_unreachable(
- "trying to get mapping beyond region that is isolated from above");
+ if (!allowOutOfScope) {
+ for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
+ if (r == region)
+ break;
+ if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
+ llvm_unreachable("trying to get mapping beyond region that is "
+ "isolated from above");
+ }
}
#endif // NDEBUG
return it->second;
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 5bd95169233b1..961cf34f0ee9a 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -83,7 +83,8 @@ transform::TransformState::getPayloadValues(Value handleValue) const {
}
LogicalResult transform::TransformState::getHandlesForPayloadOp(
- Operation *op, SmallVectorImpl<Value> &handles) const {
+ Operation *op, SmallVectorImpl<Value> &handles,
+ bool includeOutOfScope) const {
bool found = false;
for (const auto &[region, mapping] : llvm::reverse(mappings)) {
auto iterator = mapping.reverse.find(op);
@@ -92,7 +93,8 @@ LogicalResult transform::TransformState::getHandlesForPayloadOp(
found = true;
}
// Stop looking when reaching a region that is isolated from above.
- if (region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
+ if (!includeOutOfScope &&
+ region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
break;
}
@@ -100,7 +102,8 @@ LogicalResult transform::TransformState::getHandlesForPayloadOp(
}
LogicalResult transform::TransformState::getHandlesForPayloadValue(
- Value payloadValue, SmallVectorImpl<Value> &handles) const {
+ Value payloadValue, SmallVectorImpl<Value> &handles,
+ bool includeOutOfScope) const {
bool found = false;
for (const auto &[region, mapping] : llvm::reverse(mappings)) {
auto iterator = mapping.reverseValues.find(payloadValue);
@@ -109,7 +112,8 @@ LogicalResult transform::TransformState::getHandlesForPayloadValue(
found = true;
}
// Stop looking when reaching a region that is isolated from above.
- if (region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
+ if (!includeOutOfScope &&
+ region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
break;
}
@@ -343,7 +347,8 @@ transform::TransformState::replacePayloadOp(Operation *op,
#ifndef NDEBUG
for (Value opResult : op->getResults()) {
SmallVector<Value> valueHandles;
- (void)getHandlesForPayloadValue(opResult, valueHandles);
+ (void)getHandlesForPayloadValue(opResult, valueHandles,
+ /*includeOutOfScope=*/true);
assert(valueHandles.empty() && "expected no mapping to old results");
}
#endif // NDEBUG
@@ -351,10 +356,10 @@ transform::TransformState::replacePayloadOp(Operation *op,
// Drop the mapping between the op and all handles that point to it. Fail if
// there are no handles.
SmallVector<Value> opHandles;
- if (failed(getHandlesForPayloadOp(op, opHandles)))
+ if (failed(getHandlesForPayloadOp(op, opHandles, /*includeOutOfScope=*/true)))
return failure();
for (Value handle : opHandles) {
- Mappings &mappings = getMapping(handle);
+ Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
dropMappingEntry(mappings.reverse, op, handle);
}
@@ -385,7 +390,7 @@ transform::TransformState::replacePayloadOp(Operation *op,
// element from an array invalidates iterators; merely changing the value of
// elements does not.
for (Value handle : opHandles) {
- Mappings &mappings = getMapping(handle);
+ Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
auto it = mappings.direct.find(handle);
if (it == mappings.direct.end())
continue;
@@ -410,11 +415,12 @@ transform::TransformState::replacePayloadOp(Operation *op,
LogicalResult
transform::TransformState::replacePayloadValue(Value value, Value replacement) {
SmallVector<Value> valueHandles;
- if (failed(getHandlesForPayloadValue(value, valueHandles)))
+ if (failed(getHandlesForPayloadValue(value, valueHandles,
+ /*includeOutOfScope=*/true)))
return failure();
for (Value handle : valueHandles) {
- Mappings &mappings = getMapping(handle);
+ Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
dropMappingEntry(mappings.reverseValues, value, handle);
// If replacing with null, that is erasing the mapping, drop the mapping
@@ -764,7 +770,7 @@ checkRepeatedConsumptionInOperand(ArrayRef<T> payload,
void transform::TransformState::compactOpHandles() {
for (Value handle : opHandlesToCompact) {
- Mappings &mappings = getMapping(handle);
+ Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
llvm::erase_value(mappings.direct[handle], nullptr);
}
opHandlesToCompact.clear();
@@ -1346,7 +1352,8 @@ void transform::TrackingListener::notifyOperationReplaced(
// Replace op handle.
SmallVector<Value> opHandles;
- if (failed(getTransformState().getHandlesForPayloadOp(op, opHandles))) {
+ if (failed(getTransformState().getHandlesForPayloadOp(
+ op, opHandles, /*includeOutOfScope=*/true))) {
// Op is not tracked.
return;
}
diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir
index 062ec6e12e35e..992f78623a825 100644
--- a/mlir/test/Dialect/Transform/test-pattern-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir
@@ -131,11 +131,47 @@ transform.sequence failures(propagate) {
transform.apply_patterns to %0 {
transform.apply_patterns.transform.test_patterns
} : !transform.any_op
+ // No marker should be printed.
transform.test_print_remark_at_operand %1, "op was deleted" : !transform.any_op
}
// -----
+// CHECK-LABEL: func @erase_tracked_op_in_named_sequence()
+// CHECK: "test.container"() ({
+// CHECK-NEXT: ^bb0:
+// CHECK-NEXT: }) : () -> ()
+module {
+ func.func @erase_tracked_op_in_named_sequence() {
+ "test.container"() ({
+ // expected-remark @below {{matched op}}
+ %0 = "test.erase_op"() {replace_with_new_op = "test.foo"} : () -> (i32)
+ }) : () -> ()
+ return
+ }
+
+ module attributes { transform.with_named_sequence } {
+ transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> () {
+ transform.apply_patterns to %arg0 {
+ transform.apply_patterns.transform.test_patterns
+ } : !transform.any_op
+ transform.yield
+ }
+
+ transform.sequence failures(propagate) {
+ ^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["test.erase_op"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.test_print_remark_at_operand %1, "matched op" : !transform.any_op
+ include @foo failures(propagate) (%0) : (!transform.any_op) -> ()
+ // No marker should be printed.
+ transform.test_print_remark_at_operand %1, "op was deleted" : !transform.any_op
+ }
+ }
+}
+
+// -----
+
// CHECK-LABEL: func @canonicalization(
// CHECK: %[[c5:.*]] = arith.constant 5 : index
// CHECK: return %[[c5]]
More information about the Mlir-commits
mailing list