[Mlir-commits] [mlir] d8cab3f - [mlir][Transform] Fix dropReverseMapping early exit condition
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Oct 13 08:30:55 PDT 2022
Author: Nicolas Vasilache
Date: 2022-10-13T08:30:45-07:00
New Revision: d8cab3f407070c6d80396553ce024e17a0659b04
URL: https://github.com/llvm/llvm-project/commit/d8cab3f407070c6d80396553ce024e17a0659b04
DIFF: https://github.com/llvm/llvm-project/commit/d8cab3f407070c6d80396553ce024e17a0659b04.diff
LOG: [mlir][Transform] Fix dropReverseMapping early exit condition
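Previously, the early-exit condition was inverted: the function returned whenever a reverse mapping existed, so the erasure never triggered (and a missing mapping led to dereferencing the end iterator), resulting in surprising behavior.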
Differential Revision: https://reviews.llvm.org/D135881
Added:
Modified:
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/test/Dialect/Transform/test-interpreter.mlir
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 2810444ea864d..9b85af35783e7 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -92,7 +92,7 @@ transform::TransformState::setPayloadOps(Value value,
void transform::TransformState::dropReverseMapping(Mappings &mappings,
Operation *op, Value value) {
auto it = mappings.reverse.find(op);
- if (it != mappings.reverse.end())
+ if (it == mappings.reverse.end())
return;
llvm::erase_value(it->getSecond(), value);
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 735491a05aa68..c7d02d2bb9341 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -895,3 +895,28 @@ transform.with_pdl_patterns {
transform.cast %2 : !transform.op<"test.some_op"> to !pdl.operation
}
}
+
+// -----
+
+// Regression test for dropReverseMapping: once the payload op associated with
+// %0 is erased, %0 must no longer be reported as a handle for ops nested under
+// %arg1, so the count reported for %arg1 drops from 3 to 2.
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 : !pdl.operation failures(propagate) {
+ ^bb0(%arg1: !pdl.operation):
+ %0 = pdl_match @some in %arg1 : (!pdl.operation) -> !pdl.operation
+ // Here, the handles associated with ops nested under %arg1 (including its
+ // payload op itself) are {%arg0, %arg1, %0}.
+ // expected-remark @below {{3 handles nested under}}
+ transform.test_report_number_of_tracked_handles_nested_under %arg1
+ // expected-remark @below {{erased}}
+ transform.test_emit_remark_and_erase_operand %0, "erased"
+ // After the payload of %0 is erased, only {%arg0, %arg1} remain associated.
+ // expected-remark @below {{2 handles nested under}}
+ transform.test_report_number_of_tracked_handles_nested_under %arg1
+ }
+
+ pdl.pattern @some : benefit(1) {
+ %0 = pdl.operation "test.some_op"
+ pdl.rewrite %0 with "transform.dialect"
+ }
+}
+
+"test.some_op"() : () -> ()
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 6a39a2e2df629..b890af57f8d00 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -328,6 +328,26 @@ DiagnosedSilenceableFailure mlir::transform::TestDialectOpType::checkPayload(
return DiagnosedSilenceableFailure::success();
}
+/// Declares the memory effects of this test op: its target handle is only
+/// read (not consumed), so the same handle may be used again afterwards.
+void mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  Value target = getTarget();
+  transform::onlyReadsHandle(target, effects);
+}
+
+/// Emits a remark with the number of transform handles associated with the
+/// payload ops of the target handle and with every op nested inside them.
+/// The walk visits each payload op itself as well as its descendants; a handle
+/// is counted once for every visited op it is associated with.
+DiagnosedSilenceableFailure
+mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::apply(
+    transform::TransformResults &results, transform::TransformState &state) {
+  int64_t numTrackedHandles = 0;
+  // Scratch buffer shared by all visited ops; emptied before every lookup.
+  SmallVector<Value> handlesForOp;
+  for (Operation *payloadOp : state.getPayloadOps(getTarget())) {
+    payloadOp->walk([&](Operation *nestedOp) {
+      handlesForOp.clear();
+      // The lookup's LogicalResult is deliberately ignored; only the handles
+      // it places in the buffer are counted.
+      (void)state.getHandlesForPayloadOp(nestedOp, handlesForOp);
+      numTrackedHandles += handlesForOp.size();
+    });
+  }
+  emitRemark() << numTrackedHandles << " handles nested under";
+  return DiagnosedSilenceableFailure::success();
+}
+
namespace {
/// Test extension of the Transform dialect. Registers additional ops and
/// declares PDL as dependent dialect since the additional ops are using PDL
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index aaec014ac660a..9ca267565b3a1 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -253,4 +253,13 @@ def TestCopyPayloadOp
let assemblyFormat = "$handle attr-dict";
}
+// Test op that reports how many transform handles are associated with the
+// payload ops of its target handle and all ops nested within them.
+def TestReportNumberOfTrackedHandlesNestedUnder
+  : Op<Transform_Dialect, "test_report_number_of_tracked_handles_nested_under",
+       [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+        DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let summary = "Reports the number of tracked handles nested under an op";
+  let description = [{
+    Walks every payload op associated with `target`, including the op itself
+    and every op nested within it, and emits a remark with the total number
+    of transform handles associated with the visited ops. The `target` handle
+    is only read, not consumed.
+  }];
+  let arguments = (ins PDL_Operation:$target);
+  let assemblyFormat = "$target attr-dict";
+  let cppNamespace = "::mlir::test";
+}
+
#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
More information about the Mlir-commits
mailing list