[Mlir-commits] [mlir] d8cab3f - [mlir][Transform] Fix dropReverseMapping early exit condition

Nicolas Vasilache llvmlistbot at llvm.org
Thu Oct 13 08:30:55 PDT 2022


Author: Nicolas Vasilache
Date: 2022-10-13T08:30:45-07:00
New Revision: d8cab3f407070c6d80396553ce024e17a0659b04

URL: https://github.com/llvm/llvm-project/commit/d8cab3f407070c6d80396553ce024e17a0659b04
DIFF: https://github.com/llvm/llvm-project/commit/d8cab3f407070c6d80396553ce024e17a0659b04.diff

LOG: [mlir][Transform] Fix dropReverseMapping early exit condition

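Previously, the early-exit condition was inverted: the erasure never triggered when a reverse mapping existed, and `end()` was dereferenced when none did, resulting in surprising behavior.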

Differential Revision: https://reviews.llvm.org/D135881

Added: 
    

Modified: 
    mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
    mlir/test/Dialect/Transform/test-interpreter.mlir
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 2810444ea864d..9b85af35783e7 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -92,7 +92,7 @@ transform::TransformState::setPayloadOps(Value value,
 void transform::TransformState::dropReverseMapping(Mappings &mappings,
                                                    Operation *op, Value value) {
   auto it = mappings.reverse.find(op);
-  if (it != mappings.reverse.end())
+  if (it == mappings.reverse.end())
     return;
 
   llvm::erase_value(it->getSecond(), value);

diff  --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 735491a05aa68..c7d02d2bb9341 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -895,3 +895,28 @@ transform.with_pdl_patterns {
     transform.cast %2 : !transform.op<"test.some_op"> to !pdl.operation
   }
 }
+
+// -----
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 : !pdl.operation failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = pdl_match @some in %arg1 : (!pdl.operation) -> !pdl.operation
+    // Here, the handles tracked for %arg1's payload and ops nested in it are {%arg0, %arg1, %0}.
+    // expected-remark @below {{3 handles nested under}}
+    transform.test_report_number_of_tracked_handles_nested_under %arg1
+    // expected-remark @below {{erased}}
+    transform.test_emit_remark_and_erase_operand %0, "erased"
+    // After erasing %0 above, the remaining handles are only {%arg0, %arg1}.
+    // expected-remark @below {{2 handles nested under}}
+    transform.test_report_number_of_tracked_handles_nested_under %arg1
+  }
+
+  pdl.pattern @some : benefit(1) {
+    %0 = pdl.operation "test.some_op"
+    pdl.rewrite %0 with "transform.dialect"
+  }
+}
+
+"test.some_op"() : () -> ()

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 6a39a2e2df629..b890af57f8d00 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -328,6 +328,26 @@ DiagnosedSilenceableFailure mlir::transform::TestDialectOpType::checkPayload(
   return DiagnosedSilenceableFailure::success();
 }
 
+void mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  mlir::transform::onlyReadsHandle(getTarget(), effects); // Read, not consumed.
+}
+
+DiagnosedSilenceableFailure
+mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::apply(
+    transform::TransformResults &results, transform::TransformState &state) {
+  // Count the handles mapped to each target payload op and every op nested in it.
+  int64_t numHandles = 0;
+  for (Operation *target : state.getPayloadOps(getTarget()))
+    target->walk([&](Operation *nestedOp) {
+      SmallVector<Value> associated;
+      (void)state.getHandlesForPayloadOp(nestedOp, associated);
+      numHandles += associated.size();
+    });
+  emitRemark() << numHandles << " handles nested under";
+  return DiagnosedSilenceableFailure::success();
+}
+
 namespace {
 /// Test extension of the Transform dialect. Registers additional ops and
 /// declares PDL as dependent dialect since the additional ops are using PDL

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index aaec014ac660a..9ca267565b3a1 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -253,4 +253,13 @@ def TestCopyPayloadOp
   let assemblyFormat = "$handle attr-dict";
 }
 
+def TestReportNumberOfTrackedHandlesNestedUnder : Op<Transform_Dialect,
+    "test_report_number_of_tracked_handles_nested_under",
+    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let cppNamespace = "::mlir::test"; // C++ class lives in mlir::test.
+  let arguments = (ins PDL_Operation:$target); // Payload root(s) to walk.
+  let assemblyFormat = "$target attr-dict";
+}
+
 #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD


        


More information about the Mlir-commits mailing list