[Mlir-commits] [mlir] 07fef17 - [mlir][transform] Better debugging facilities for invalid API usage
Matthias Springer
llvmlistbot at llvm.org
Tue Apr 11 23:49:37 PDT 2023
Author: Matthias Springer
Date: 2023-04-12T15:49:28+09:00
New Revision: 07fef178e8b0320dcfeaaef26b5a93eb7dfbf833
URL: https://github.com/llvm/llvm-project/commit/07fef178e8b0320dcfeaaef26b5a93eb7dfbf833
DIFF: https://github.com/llvm/llvm-project/commit/07fef178e8b0320dcfeaaef26b5a93eb7dfbf833.diff
LOG: [mlir][transform] Better debugging facilities for invalid API usage
This revision adds more checks to the transform dialect's "expensive-checks" mode that detect the most common cases of:
* Missing `consumesHandle` side effects on transform ops.
* Patterns that remove operations but do not notify the transform dialect.
In essence, these additional checks look for dangling pointers to erased payload ops in the transform dialect state and either crash the program (by dereferencing freed memory) or trigger an assertion failure. It is recommended to run these extra checks with ASAN; otherwise, certain failures may not be detected. The ASAN error message can also be used to find the faulty transform op/pattern.
This change also fixes a few faulty transform ops.
Differential Revision: https://reviews.llvm.org/D147447
Added:
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 9f4e3d8e089ff..4dcc5dd3e471d 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -563,6 +563,18 @@ class TransformState {
/// Each region must be an ancestor of the following regions in this list.
/// These are also the keys for "mappings".
SmallVector<Region *> regionStack;
+
+ /// This cache stores operation names for operations that are tracked in the
+ /// transform dialect state. It is used to detect missing memory side effects
+ /// and op tracking.
+ ///
+ /// All tracked ops are added to this cache before a transform op is applied.
+ /// After the application of the transform op, the names of all tracked ops
+ /// are compared with the names in the cache. If there is a mismatch (or a
+ /// crash), op tracking is missing somewhere. This is typically a missing
+ /// "consumesHandle" side effect or a pattern that removes an op without
+ /// notifying a TrackingListener.
+ DenseMap<Operation *, OperationName> cachedNames;
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
};
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index d74a603277304..1f4ac08f1259e 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1818,7 +1818,8 @@ transform::HoistPadOp::applyToOne(tensor::PadOp target,
transform::TransformState &state) {
tensor::PadOp hoistedPadOp;
SmallVector<GenericOp> transposeOps;
- IRRewriter rewriter(target->getContext());
+ TrackingListener listener(state);
+ IRRewriter rewriter(target->getContext(), &listener);
FailureOr<Value> result =
hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(),
hoistedPadOp, transposeOps);
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 0f49065a70161..41f1a3a18ec7d 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -345,6 +345,15 @@ transform::TransformState::replacePayloadOp(Operation *op,
}
#endif // NDEBUG
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+ if (options.getExpensiveChecksEnabled()) {
+ auto it = cachedNames.find(op);
+ assert(it != cachedNames.end() && "entry not found");
+ assert(it->second == op->getName() && "operation name mismatch");
+ cachedNames.erase(it);
+ }
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+
// TODO: consider invalidating the handles to nested objects here.
// If replacing with null, that is erasing the mapping, drop the mapping
@@ -357,6 +366,16 @@ transform::TransformState::replacePayloadOp(Operation *op,
return success();
}
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+ if (options.getExpensiveChecksEnabled()) {
+ auto insertion = cachedNames.insert({replacement, replacement->getName()});
+ if (!insertion.second) {
+ assert(insertion.first->second == replacement->getName() &&
+ "operation is already cached with a different name");
+ }
+ }
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+
// Otherwise, replace the pointed-to object of all handles while preserving
// their relative order. First, replace the mapped operation if present.
for (Value handle : opHandles) {
@@ -722,6 +741,28 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
FULL_LDBG("--not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
}
}
+
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+ // Cache Operation* -> OperationName mappings. These will be checked after
+ // the transform has been applied to detect incorrect memory side effects
+ // and missing op tracking.
+ for (Mappings &mapping : llvm::make_second_range(mappings)) {
+ for (Operation *op : llvm::make_first_range(mapping.reverse)) {
+ auto insertion = cachedNames.insert({op, op->getName()});
+ if (!insertion.second) {
+ if (insertion.first->second != op->getName()) {
+ // Operation is already in the cache, but with a different name.
+ DiagnosedDefiniteFailure diag =
+ emitDefiniteFailure(transform->getLoc())
+ << "expensive checks failure: operation mismatch, expected "
+ << insertion.first->second;
+ diag.attachNote(op->getLoc()) << "payload op: " << op->getName();
+ return diag;
+ }
+ }
+ }
+ }
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
}
// Find which operands are consumed.
@@ -748,11 +789,23 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
// IR after that.
SmallVector<Value> origOpFlatResults;
SmallVector<Operation *> origAssociatedOps;
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+ DenseSet<Operation *> consumedPayloadOps;
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
for (unsigned index : consumedOperands) {
Value operand = transform->getOperand(index);
if (operand.getType().isa<TransformHandleTypeInterface>()) {
- for (Operation *payloadOp : getPayloadOps(operand))
+ for (Operation *payloadOp : getPayloadOps(operand)) {
llvm::append_range(origOpFlatResults, payloadOp->getResults());
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+ if (options.getExpensiveChecksEnabled()) {
+ // Store all consumed payload ops (and their nested ops) in a set for
+ // extra error checking.
+ payloadOp->walk(
+ [&](Operation *op) { consumedPayloadOps.insert(op); });
+ }
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+ }
continue;
}
if (operand.getType().isa<TransformValueHandleTypeInterface>()) {
@@ -812,6 +865,61 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
}
}
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+ if (options.getExpensiveChecksEnabled()) {
+ // Remove erased ops from the transform state.
+ for (Operation *op : consumedPayloadOps) {
+ // This payload op was consumed but it may still be mapped to one or
+ // multiple handles. Forget all handles that are mapped to the op, so that
+ // there are no dangling pointers in the transform dialect state. This is
+ // necessary so that the `cachedNames`-based checks work correctly.
+ //
+ // Note: Dangling pointers to erased payload ops are allowed if the
+ // corresponding handles are not used anymore. There is another
+ // "expensive-check" that looks for future uses of dangling payload op
+ // pointers (through arbitrary handles). Removing handles to erased ops
+ // does not interfere with the other expensive checks: handle invalidation
+ // happens earlier and keeps track of invalidated handles with
+ // pre-generated error messages, so we do not need the association to
+ // still be there when the invalidated handle is accessed.
+ SmallVector<Value> handles;
+ (void)getHandlesForPayloadOp(op, handles);
+ for (Value handle : handles)
+ forgetMapping(handle, /*origOpFlatResults=*/ValueRange());
+ cachedNames.erase(op);
+ }
+
+ // Check cached operation names.
+ for (Mappings &mapping : llvm::make_second_range(mappings)) {
+ for (Operation *op : llvm::make_first_range(mapping.reverse)) {
+ // Make sure that the name of the op has not changed. If it has changed,
+ // the op was removed and a new op was allocated at the same memory
+ // location. This means that we are missing op tracking somewhere.
+ auto cacheIt = cachedNames.find(op);
+ if (cacheIt == cachedNames.end()) {
+ DiagnosedDefiniteFailure diag =
+ emitDefiniteFailure(transform->getLoc())
+ << "expensive checks failure: operation not found in cache";
+ diag.attachNote(op->getLoc()) << "payload op";
+ return diag;
+ }
+ // If the `getName` call (or the above `attachNote`) is crashing, we
+ // have a dangling pointer. This usually means that an op was erased but
+ // the transform dialect was not made aware of that; e.g., missing
+ // "consumesHandle" or rewriter usage.
+ if (cacheIt->second != op->getName()) {
+ DiagnosedDefiniteFailure diag =
+ emitDefiniteFailure(transform->getLoc())
+ << "expensive checks failure: operation mismatch, expected "
+ << cacheIt->second;
+ diag.attachNote(op->getLoc()) << "payload op: " << op->getName();
+ return diag;
+ }
+ }
+ }
+ }
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+
for (OpResult result : transform->getResults()) {
assert(result.getDefiningOp() == transform.getOperation() &&
"payload IR association for a value other than the result of the "
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 95a004a7b049b..15583315cc9b3 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -349,7 +349,8 @@ transform::AlternativesOp::apply(transform::TransformResults &results,
if (!failed) {
// We will be using the clones, so cancel their scheduled deletion.
deleteClones.release();
- IRRewriter rewriter(getContext());
+ TrackingListener listener(state);
+ IRRewriter rewriter(getContext(), &listener);
for (const auto &kvp : llvm::zip(originals, clones)) {
Operation *original = std::get<0>(kvp);
Operation *clone = std::get<1>(kvp);
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 79b5256ef1e23..9803538e8b3b2 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -352,6 +352,12 @@ DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply(
return DiagnosedSilenceableFailure::success();
}
+void mlir::test::TestEmitRemarkAndEraseOperandOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ transform::consumesHandle(getTarget(), effects);
+ transform::modifiesPayload(effects);
+}
+
DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfResultsOp::applyToOne(
Operation *target, transform::ApplyToEachResultList &results,
transform::TransformState &state) {
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index 7c4e02ce7e150..23184725723b3 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -203,7 +203,8 @@ def TestBranchingTransformOpTerminator
def TestEmitRemarkAndEraseOperandOp
: Op<Transform_Dialect, "test_emit_remark_and_erase_operand",
[DeclareOpInterfaceMethods<TransformOpInterface>,
- MemoryEffectsOpInterface, FunctionalStyleTransformOpTrait]> {
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ FunctionalStyleTransformOpTrait]> {
let arguments = (ins PDL_Operation:$target, StrAttr:$remark,
UnitAttr:$fail_after_erase);
let assemblyFormat = "$target `,` $remark attr-dict";
More information about the Mlir-commits
mailing list