[Mlir-commits] [mlir] 95f495c - [mlir][transform] add a check for nested consumption in ApplyEachOpTrait

Alex Zinenko llvmlistbot at llvm.org
Fri Jun 9 03:44:30 PDT 2023


Author: Alex Zinenko
Date: 2023-06-09T10:44:24Z
New Revision: 95f495c7b3a6bc2ed47bb6f9977256ec0f841e52

URL: https://github.com/llvm/llvm-project/commit/95f495c7b3a6bc2ed47bb6f9977256ec0f841e52
DIFF: https://github.com/llvm/llvm-project/commit/95f495c7b3a6bc2ed47bb6f9977256ec0f841e52.diff

LOG: [mlir][transform] add a check for nested consumption in ApplyEachOpTrait

ApplyEachOpTrait (`TransformEachOpTrait` in the code) applies the transform
to each payload op associated with its operand handle, one by one, in
order. If the handle is consumed, this usually indicates that the
associated payload ops are erased or rewritten. Add a check, run when
expensive checks are enabled, that we don't consume an ancestor payload
operation before consuming its descendant, as the descendant is then
likely to be a dangling pointer. Transform operations for which this is
legitimate behavior (i.e., they consume the handle but don't actually erase
or rewrite the payload operation) should implement the interface directly
and allow for nested or repeated payload operations in the handle.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D152510

Added: 
    mlir/test/Dialect/Transform/apply-foreach-nested.mlir

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
    mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 28972f1b59fe5..1afac0aba48dc 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -187,6 +187,8 @@ class TransformState {
   detail::makeTransformStateForTesting(Region *region, Operation *payloadRoot);
 
 public:
+  const TransformOptions &getOptions() const { return options; }
+
   /// Returns the op at which the transformation state is rooted. This is
   /// typically helpful for transformations that apply globally.
   Operation *getTopLevel() const;
@@ -1352,6 +1354,12 @@ applyTransformToEach(TransformOpTy transformOp, Range &&targets,
   return DiagnosedSilenceableFailure::success();
 }
 
+/// Reports an error at `loc` and returns failure if `targets` lists an
+/// operation before one of its descendants or before another copy of itself.
+/// Implementation detail of expensive checks in `TransformEachOpTrait::apply`.
+LogicalResult checkNestedConsumption(Location loc,
+                                     ArrayRef<Operation *> targets);
+
 } // namespace detail
 } // namespace transform
 } // namespace mlir
@@ -1360,7 +1368,18 @@ template <typename OpTy>
 mlir::DiagnosedSilenceableFailure
 mlir::transform::TransformEachOpTrait<OpTy>::apply(
     TransformResults &transformResults, TransformState &state) {
-  auto targets = state.getPayloadOps(this->getOperation()->getOperand(0));
+  Value handle = this->getOperation()->getOperand(0);
+  auto targets = state.getPayloadOps(handle);
+
+  // If the operand is consumed, check if it is associated with operations that
+  // may be erased before their nested operations are.
+  if (state.getOptions().getExpensiveChecksEnabled() &&
+      isHandleConsumed(handle, cast<transform::TransformOpInterface>(
+                                   this->getOperation())) &&
+      failed(detail::checkNestedConsumption(this->getOperation()->getLoc(),
+                                            llvm::to_vector(targets)))) {
+    return DiagnosedSilenceableFailure::definiteFailure();
+  }
 
   // Step 1. Handle the corner case where no target is specified.
   // This is typically the case when the matcher fails to apply and we need to

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 068fa0bec328e..a6c5629933294 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -1330,6 +1330,28 @@ void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound(
 // Utilities for TransformEachOpTrait.
 //===----------------------------------------------------------------------===//
 
+LogicalResult
+transform::detail::checkNestedConsumption(Location loc,
+                                          ArrayRef<Operation *> targets) {
+  for (auto &&[position, parent] : llvm::enumerate(targets)) {
+    for (Operation *child : targets.drop_front(position + 1)) {
+      if (parent->isAncestor(child)) { // Also true when child == parent.
+        InFlightDiagnostic diag =
+            emitError(loc)
+            << "transform operation consumes a handle pointing to an ancestor "
+               "payload operation before its descendant";
+        diag.attachNote()
+            << "the ancestor is likely erased or rewritten before the "
+               "descendant is accessed, leading to undefined behavior";
+        diag.attachNote(parent->getLoc()) << "ancestor payload op";
+        diag.attachNote(child->getLoc()) << "descendant payload op";
+        return diag; // Converts to a failure() LogicalResult.
+      }
+    }
+  }
+  return success(); // No op precedes one of its descendants.
+}
+
 LogicalResult
 transform::detail::checkApplyToOne(Operation *transformOp,
                                    Location payloadOpLoc,

diff  --git a/mlir/test/Dialect/Transform/apply-foreach-nested.mlir b/mlir/test/Dialect/Transform/apply-foreach-nested.mlir
new file mode 100644
index 0000000000000..d2f71644bbc51
--- /dev/null
+++ b/mlir/test/Dialect/Transform/apply-foreach-nested.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt %s --split-input-file --verify-diagnostics \
+// RUN:          --pass-pipeline="builtin.module(test-transform-dialect-interpreter{enable-expensive-checks=1 bind-first-extra-to-ops=scf.for})"
+
+func.func private @bar()
+
+func.func @foo() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  // expected-note @below {{ancestor payload op}}
+  scf.for %i = %c0 to %c1 step %c10 {
+    // expected-note @below {{descendant payload op}}
+    scf.for %j = %c0 to %c1 step %c10 {
+      func.call @bar() : () -> ()
+    }
+  }
+  return
+}
+
+transform.sequence failures(suppress) {
+^bb0(%arg0: !transform.any_op, %arg1: !transform.op<"scf.for">):
+  %1 = transform.test_reverse_payload_ops %arg1 : (!transform.op<"scf.for">) -> !transform.op<"scf.for">
+  // expected-error @below {{transform operation consumes a handle pointing to an ancestor payload operation before its descendant}}
+  // expected-note @below {{the ancestor is likely erased or rewritten before the descendant is accessed, leading to undefined behavior}}
+  transform.test_consume_operand_each %1 : !transform.op<"scf.for">
+}
+
+// -----
+
+func.func private @bar()
+
+func.func @foo() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  scf.for %i = %c0 to %c1 step %c10 {
+    scf.for %j = %c0 to %c1 step %c10 {
+      func.call @bar() : () -> ()
+    }
+  }
+  return
+}
+
+// No error here: without reversal, ancestors are consumed before descendants.
+transform.sequence failures(suppress) {
+^bb0(%arg0: !transform.any_op, %arg1: !transform.op<"scf.for">):
+  transform.test_consume_operand_each %arg1 : !transform.op<"scf.for">
+}

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index 85b0440277dc1..76111653dd922 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -108,6 +108,22 @@ def TestConsumeOperand : Op<Transform_Dialect, "test_consume_operand",
   let cppNamespace = "::mlir::test";
 }
 
+def TestConsumeOperandEach : Op<Transform_Dialect, "test_consume_operand_each",
+     [TransformOpInterface, TransformEachOpTrait,
+      MemoryEffectsOpInterface, FunctionalStyleTransformOpTrait]> {
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let assemblyFormat = "$target attr-dict `:` type($target)";
+  let cppNamespace = "::mlir::test";
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state) {
+      return ::mlir::DiagnosedSilenceableFailure::success();
+    }
+  }];
+}
+
 def TestConsumeOperandOfOpKindOrFail
   : Op<Transform_Dialect, "test_consume_operand_of_op_kind_or_fail",
        [DeclareOpInterfaceMethods<TransformOpInterface>,


        


More information about the Mlir-commits mailing list