[Mlir-commits] [mlir] c580bd2 - [mlir][transform] fix handle invalidation check for reentrant regions
Alex Zinenko
llvmlistbot at llvm.org
Fri Jun 23 01:21:00 PDT 2023
Author: Alex Zinenko
Date: 2023-06-23T08:20:49Z
New Revision: c580bd261c14bde64fff19eaf5efbf7fc2ea7b4f
URL: https://github.com/llvm/llvm-project/commit/c580bd261c14bde64fff19eaf5efbf7fc2ea7b4f
DIFF: https://github.com/llvm/llvm-project/commit/c580bd261c14bde64fff19eaf5efbf7fc2ea7b4f.diff
LOG: [mlir][transform] fix handle invalidation check for reentrant regions
When exiting the scope of a region attached to a transform op, clean up
the handle invalidation checks associated with handles defined in this
region. Otherwise, these checks may spuriously trigger on the next entry to
the region even though there is no incorrect usage.
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D153545
Added:
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/test/Dialect/Transform/expensive-checks.mlir
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 5a22ae51d21dd..9b45c15777040 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -279,14 +279,7 @@ class TransformState {
/// Forgets the mapping from or to values defined in the associated
/// transform IR region, and restores the mapping that existed before
/// entering this scope.
- ~RegionScope() {
- state.mappings.erase(region);
- if (storedMappings.has_value())
- state.mappings.swap(*storedMappings);
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
- state.regionStack.pop_back();
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
- }
+ ~RegionScope();
private:
/// Tag structure for differentiating the constructor for isolated regions.
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 32637b3d4442f..013d0e29ef5ce 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -1075,6 +1075,51 @@ transform::TransformState::Extension::replacePayloadValue(Value value,
return state.replacePayloadValue(value, replacement);
}
+//===----------------------------------------------------------------------===//
+// TransformState::RegionScope
+//===----------------------------------------------------------------------===//
+
+transform::TransformState::RegionScope::~RegionScope() { // Ends the scope: clears per-handle state for `region`, then restores the enclosing mappings.
+ // Remove handle invalidation notices as handles are going out of scope.
+ // The same region may be re-entered, and stale notices would then produce
+ // spurious invalidation errors on the next entry.
+ for (Block &block : *region) { // Only values defined directly in this region's blocks are visited here.
+ for (Value handle : block.getArguments()) { // Block-argument handles.
+ state.invalidatedHandles.erase(handle);
+ }
+ for (Operation &op : block) {
+ for (Value handle : op.getResults()) { // Op-result handles.
+ state.invalidatedHandles.erase(handle);
+ }
+ }
+ }
+
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+ // Remember pointers to payload ops referenced by the handles going out of
+ // scope. This must happen before the region's mappings are erased below.
+ SmallVector<Operation *> referencedOps =
+ llvm::to_vector(llvm::make_first_range(state.mappings[region].reverse)); // NOTE(review): operator[] default-inserts if `region` is absent; assumes the scope registered it on entry.
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+
+ state.mappings.erase(region); // Forget mappings from/to values defined in this region.
+ if (storedMappings.has_value()) // Restore the mappings saved when entering this scope, if any.
+ state.mappings.swap(*storedMappings);
+
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+ // If the last handle to a payload op has gone out of scope, we no longer
+ // need to store the cached name. Pointers may get reused, leading to
+ // incorrect associations in the cache.
+ for (Operation *op : referencedOps) {
+ SmallVector<Value> handles;
+ if (succeeded(state.getHandlesForPayloadOp(op, handles))) // NOTE(review): presumably succeeds iff some live handle still maps to `op`.
+ continue;
+ state.cachedNames.erase(op);
+ }
+
+ state.regionStack.pop_back(); // NOTE(review): presumably pairs with a push performed when the scope was entered.
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+}
+
//===----------------------------------------------------------------------===//
// TransformResults
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/expensive-checks.mlir b/mlir/test/Dialect/Transform/expensive-checks.mlir
index e35c1791da939..2e6318918c44e 100644
--- a/mlir/test/Dialect/Transform/expensive-checks.mlir
+++ b/mlir/test/Dialect/Transform/expensive-checks.mlir
@@ -364,3 +364,49 @@ module {
transform.test_consume_operand %0 { allow_repeated_handles } : !transform.any_op
}
}
+
+// -----
+
+// Re-entering the region should not trigger the consumption error from a
+// previous execution of the region when the consumed handle is defined inside it.
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+ transform.test_re_enter_region {
+ %0 = transform.test_produce_self_handle_or_forward_operand : () -> !transform.any_op // %0 is defined inside the re-entered region.
+ transform.test_consume_operand %0 : !transform.any_op // Consumed once per entry.
+ transform.yield
+ }
+}
+
+// -----
+
+// Re-entering the region should not trigger the consumption error from a
+// previous execution of the region when the consumed handle is a block argument.
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+ %0 = transform.test_produce_self_handle_or_forward_operand : () -> !transform.any_op
+ transform.test_re_enter_region %0 : !transform.any_op {
+ ^bb0(%arg1: !transform.any_op): // %arg1 is re-mapped to the payload of %0 on every entry.
+ transform.test_consume_operand %arg1 : !transform.any_op // Consumes the block-argument handle.
+ transform.yield
+ }
+}
+
+// -----
+
+// Consuming, inside a re-entered region, a handle defined outside of that region should trigger an error.
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+ // expected-note @below {{payload op}}
+ // expected-note @below {{handle to invalidated ops}}
+ %0 = transform.test_produce_self_handle_or_forward_operand : () -> !transform.any_op
+ transform.test_re_enter_region { // %0 is defined outside this region, so its consumption persists across entries.
+ // expected-error @below {{op uses a handle invalidated by a previously executed transform op}}
+ // expected-note @below {{invalidated by this transform op}}
+ transform.test_consume_operand %0 : !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 4b0a8f0c197e9..51c0615932b61 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -831,6 +831,47 @@ void mlir::test::ApplyTestPatternsOp::populatePatterns(
patterns.insert<ReplaceWithNewOp, EraseOp>(patterns.getContext());
}
+void mlir::test::TestReEnterRegionOp::getEffects( // Memory effects: consumes every operand handle and modifies the payload.
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ transform::consumesHandle(getOperands(), effects);
+ transform::modifiesPayload(effects);
+}
+
+DiagnosedSilenceableFailure // Runs the single-block body four times, each inside a fresh region scope.
+mlir::test::TestReEnterRegionOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+
+ SmallVector<SmallVector<transform::MappedValue>> mappings; // Payload ops per block argument, computed once before the first entry.
+ for (BlockArgument arg : getBody().front().getArguments()) { // Block argument #k takes the payload of operand #k.
+ mappings.emplace_back(llvm::to_vector(llvm::map_range(
+ state.getPayloadOps(getOperand(arg.getArgNumber())),
+ [](Operation *op) -> transform::MappedValue { return op; })));
+ }
+
+ for (int i = 0; i < 4; ++i) { // NOTE(review): the count 4 looks arbitrary; any count > 1 re-enters the region.
+ auto scope = state.make_region_scope(getBody()); // Destroyed at the end of each iteration, resetting this region's handle state.
+ for (BlockArgument arg : getBody().front().getArguments()) {
+ if (failed(state.mapBlockArgument(arg, mappings[arg.getArgNumber()]))) // Re-bind block arguments on every entry.
+ return DiagnosedSilenceableFailure::definiteFailure();
+ }
+ for (Operation &op : getBody().front().without_terminator()) { // Apply body ops in order, skipping the terminator.
+ DiagnosedSilenceableFailure diag =
+ state.applyTransform(cast<transform::TransformOpInterface>(op));
+ if (!diag.succeeded()) // Propagate the first non-success result; `scope` still unwinds on return.
+ return diag;
+ }
+ }
+ return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult mlir::test::TestReEnterRegionOp::verify() { // apply() maps operand #k to block argument #k, so the counts must match.
+ if (getNumOperands() != getBody().front().getNumArguments()) {
+ return emitOpError() << "expects as many operands as block arguments";
+ }
+ return success();
+}
+
namespace {
/// Test extension of the Transform dialect. Registers additional ops and
/// declares PDL as dependent dialect since the additional ops are using PDL
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index 6c0bef9a81ec6..594c32d165d43 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -553,4 +553,15 @@ def ApplyTestPatternsOp
let cppNamespace = "::mlir::test";
}
+def TestReEnterRegionOp // Test-only transform op used to exercise re-entering a region.
+ : Op<Transform_Dialect, "test_re_enter_region",
+ [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let arguments = (ins Variadic<AnyType>:$args); // Handles forwarded to the body's block arguments.
+ let regions = (region SizedRegion<1>:$body); // Single-block body.
+ let assemblyFormat = "($args^ `:` type($args))? attr-dict-with-keyword regions";
+ let cppNamespace = "::mlir::test";
+ let hasVerifier = 1; // Operand count must equal the body's block-argument count.
+}
+
#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
More information about the Mlir-commits
mailing list