[Mlir-commits] [mlir] 95f495c - [mlir][transform] add a check for nested consumption in ApplyEachOpTrait
Alex Zinenko
llvmlistbot at llvm.org
Fri Jun 9 03:44:30 PDT 2023
Author: Alex Zinenko
Date: 2023-06-09T10:44:24Z
New Revision: 95f495c7b3a6bc2ed47bb6f9977256ec0f841e52
URL: https://github.com/llvm/llvm-project/commit/95f495c7b3a6bc2ed47bb6f9977256ec0f841e52
DIFF: https://github.com/llvm/llvm-project/commit/95f495c7b3a6bc2ed47bb6f9977256ec0f841e52.diff
LOG: [mlir][transform] add a check for nested consumption in ApplyEachOpTrait
ApplyEachOpTrait (named `TransformEachOpTrait` in the code) applies the
transform to the payload ops associated with its operand handle one by one, in order. If a handle is consumed, this usually
indicates that the associated payload ops are erased or rewritten. Add a
check that we don't consume an ancestor payload operation before
consuming its descendant, as the latter is likely to be a dangling
pointer. Transform operations for which this is a legitimate behavior
(i.e., they consume the handle but don't actually erase or rewrite the
payload operation) should implement the interface directly and allow for
repeated handles.
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D152510
Added:
mlir/test/Dialect/Transform/apply-foreach-nested.mlir
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 28972f1b59fe5..1afac0aba48dc 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -187,6 +187,8 @@ class TransformState {
detail::makeTransformStateForTesting(Region *region, Operation *payloadRoot);
public:
+ const TransformOptions &getOptions() const { return options; }
+
/// Returns the op at which the transformation state is rooted. This is
/// typically helpful for transformations that apply globally.
Operation *getTopLevel() const;
@@ -1352,6 +1354,12 @@ applyTransformToEach(TransformOpTy transformOp, Range &&targets,
return DiagnosedSilenceableFailure::success();
}
+/// Reports an error and returns failure if `targets` contains an ancestor
+/// operation before its descendant, or the same operation twice. Implementation
+/// detail for expensive checks during `TransformEachOpTrait::apply`.
+LogicalResult checkNestedConsumption(Location loc,
+ ArrayRef<Operation *> targets);
+
} // namespace detail
} // namespace transform
} // namespace mlir
@@ -1360,7 +1368,18 @@ template <typename OpTy>
mlir::DiagnosedSilenceableFailure
mlir::transform::TransformEachOpTrait<OpTy>::apply(
TransformResults &transformResults, TransformState &state) {
- auto targets = state.getPayloadOps(this->getOperation()->getOperand(0));
+ Value handle = this->getOperation()->getOperand(0);
+ auto targets = state.getPayloadOps(handle);
+
+ // If the operand is consumed, check if it is associated with operations that
+ // may be erased before their nested operations are.
+ if (state.getOptions().getExpensiveChecksEnabled() &&
+ isHandleConsumed(handle, cast<transform::TransformOpInterface>(
+ this->getOperation())) &&
+ failed(detail::checkNestedConsumption(this->getOperation()->getLoc(),
+ llvm::to_vector(targets)))) {
+ return DiagnosedSilenceableFailure::definiteFailure();
+ }
// Step 1. Handle the corner case where no target is specified.
// This is typically the case when the matcher fails to apply and we need to
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 068fa0bec328e..a6c5629933294 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -1330,6 +1330,28 @@ void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound(
// Utilities for TransformEachOpTrait.
//===----------------------------------------------------------------------===//
+LogicalResult
+transform::detail::checkNestedConsumption(Location loc,
+ ArrayRef<Operation *> targets) {
+ for (auto &&[position, parent] : llvm::enumerate(targets)) {
+ for (Operation *child : targets.drop_front(position + 1)) {
+ if (parent->isAncestor(child)) {
+ InFlightDiagnostic diag =
+ emitError(loc)
+ << "transform operation consumes a handle pointing to an ancestor "
+ "payload operation before its descendant";
+ diag.attachNote()
+ << "the ancestor is likely erased or rewritten before the "
+ "descendant is accessed, leading to undefined behavior";
+ diag.attachNote(parent->getLoc()) << "ancestor payload op";
+ diag.attachNote(child->getLoc()) << "descendant payload op";
+ return diag;
+ }
+ }
+ }
+ return success();
+}
+
LogicalResult
transform::detail::checkApplyToOne(Operation *transformOp,
Location payloadOpLoc,
diff --git a/mlir/test/Dialect/Transform/apply-foreach-nested.mlir b/mlir/test/Dialect/Transform/apply-foreach-nested.mlir
new file mode 100644
index 0000000000000..d2f71644bbc51
--- /dev/null
+++ b/mlir/test/Dialect/Transform/apply-foreach-nested.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt %s --split-input-file --verify-diagnostics \
+// RUN: --pass-pipeline="builtin.module(test-transform-dialect-interpreter{enable-expensive-checks=1 bind-first-extra-to-ops=scf.for})"
+
+func.func private @bar()
+
+func.func @foo() {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
+ // expected-note @below {{ancestor payload op}}
+ scf.for %i = %c0 to %c1 step %c10 {
+ // expected-note @below {{descendant payload op}}
+ scf.for %j = %c0 to %c1 step %c10 {
+ func.call @bar() : () -> ()
+ }
+ }
+ return
+}
+
+transform.sequence failures(suppress) {
+^bb0(%arg0: !transform.any_op, %arg1: !transform.op<"scf.for">):
+ %1 = transform.test_reverse_payload_ops %arg1 : (!transform.op<"scf.for">) -> !transform.op<"scf.for">
+ // expected-error @below {{transform operation consumes a handle pointing to an ancestor payload operation before its descendant}}
+ // expected-note @below {{the ancestor is likely erased or rewritten before the descendant is accessed, leading to undefined behavior}}
+ transform.test_consume_operand_each %1 : !transform.op<"scf.for">
+}
+
+// -----
+
+func.func private @bar()
+
+func.func @foo() {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
+ scf.for %i = %c0 to %c1 step %c10 {
+ scf.for %j = %c0 to %c1 step %c10 {
+ func.call @bar() : () -> ()
+ }
+ }
+ return
+}
+
+// No error here, processing descendants before ancestors.
+transform.sequence failures(suppress) {
+^bb0(%arg0: !transform.any_op, %arg1: !transform.op<"scf.for">):
+ transform.test_consume_operand_each %arg1 : !transform.op<"scf.for">
+}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index 85b0440277dc1..76111653dd922 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -108,6 +108,22 @@ def TestConsumeOperand : Op<Transform_Dialect, "test_consume_operand",
let cppNamespace = "::mlir::test";
}
+def TestConsumeOperandEach : Op<Transform_Dialect, "test_consume_operand_each",
+ [TransformOpInterface, TransformEachOpTrait,
+ MemoryEffectsOpInterface, FunctionalStyleTransformOpTrait]> {
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let assemblyFormat = "$target attr-dict `:` type($target)";
+ let cppNamespace = "::mlir::test";
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::Operation *target,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state) {
+ return ::mlir::DiagnosedSilenceableFailure::success();
+ }
+ }];
+}
+
def TestConsumeOperandOfOpKindOrFail
: Op<Transform_Dialect, "test_consume_operand_of_op_kind_or_fail",
[DeclareOpInterfaceMethods<TransformOpInterface>,
More information about the Mlir-commits
mailing list