[Mlir-commits] [mlir] 78c8ab5 - Revert "[mlir][transform] Improve error message of tracking listener. (#66987)"
Vitaly Buka
llvmlistbot at llvm.org
Mon Sep 25 09:07:31 PDT 2023
Author: Vitaly Buka
Date: 2023-09-25T09:07:17-07:00
New Revision: 78c8ab5844e618162c4cf3982d05102d4da10d23
URL: https://github.com/llvm/llvm-project/commit/78c8ab5844e618162c4cf3982d05102d4da10d23
DIFF: https://github.com/llvm/llvm-project/commit/78c8ab5844e618162c4cf3982d05102d4da10d23.diff
LOG: Revert "[mlir][transform] Improve error message of tracking listener. (#66987)"
Breaks https://lab.llvm.org/buildbot/#/builders/5/builds/36953
This reverts commit a7530452fd163c84e83e662b549ade7b0fae9edf.
Added:
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/test/Dialect/Transform/test-pattern-application.mlir
mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index b45861e6190c18a..f87c0981bc4a9a1 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -797,8 +797,7 @@ class TransformResults {
/// corresponds to the given list of payload IR ops. Each result must be set
/// by the transformation exactly once in case of transformation succeeding.
/// The value must have a type implementing TransformHandleTypeInterface.
- template <typename Range>
- void set(OpResult value, Range &&ops) {
+ template <typename Range> void set(OpResult value, Range &&ops) {
int64_t position = value.getResultNumber();
assert(position < static_cast<int64_t>(operations.size()) &&
"setting results for a non-existent handle");
@@ -973,9 +972,8 @@ class TrackingListener : public RewriterBase::Listener,
///
/// Derived classes may override `findReplacementOp` to specify custom
/// replacement rules.
- virtual DiagnosedSilenceableFailure
- findReplacementOp(Operation *&result, Operation *op,
- ValueRange newValues) const;
+ virtual FailureOr<Operation *> findReplacementOp(Operation *op,
+ ValueRange newValues) const;
/// Notify the listener that the pattern failed to match the given operation,
/// and provide a callback to populate a diagnostic with the reason why the
@@ -987,9 +985,8 @@ class TrackingListener : public RewriterBase::Listener,
/// This function is called when a tracked payload op is dropped because no
/// replacement op was found. Derived classes can implement this function for
/// custom error handling.
- virtual void
- notifyPayloadReplacementNotFound(Operation *op, ValueRange values,
- DiagnosedSilenceableFailure &&diag) {}
+ virtual void notifyPayloadReplacementNotFound(Operation *op,
+ ValueRange values) {}
/// Return the single op that defines all given values (if any).
static Operation *getCommonDefiningOp(ValueRange values);
@@ -1029,9 +1026,8 @@ class ErrorCheckingTrackingListener : public TrackingListener {
bool failed() const;
protected:
- void
- notifyPayloadReplacementNotFound(Operation *op, ValueRange values,
- DiagnosedSilenceableFailure &&diag) override;
+ void notifyPayloadReplacementNotFound(Operation *op,
+ ValueRange values) override;
private:
/// The error state of this listener. "Success" indicates that no error
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 483b0e7f7a4f998..fd2cf8816ae2162 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -1291,36 +1291,27 @@ Operation *transform::TrackingListener::getCommonDefiningOp(ValueRange values) {
return defOp;
}
-DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
- Operation *&result, Operation *op, ValueRange newValues) const {
+FailureOr<Operation *>
+transform::TrackingListener::findReplacementOp(Operation *op,
+ ValueRange newValues) const {
assert(op->getNumResults() == newValues.size() &&
"invalid number of replacement values");
SmallVector<Value> values(newValues.begin(), newValues.end());
- DiagnosedSilenceableFailure diag = emitSilenceableFailure(
- getTransformOp(), "tracking listener failed to find replacement op "
- "during application of this transform op");
-
do {
// If the replacement values belong to different ops, drop the mapping.
Operation *defOp = getCommonDefiningOp(values);
- if (!defOp) {
- diag.attachNote() << "replacement values belong to different ops";
- return diag;
- }
+ if (!defOp)
+ return failure();
// If the defining op has the same type, we take it as a replacement.
- if (op->getName() == defOp->getName()) {
- result = defOp;
- return DiagnosedSilenceableFailure::success();
- }
+ if (op->getName() == defOp->getName())
+ return defOp;
// Replacing an op with a constant-like equivalent is a common
// canonicalization.
- if (defOp->hasTrait<OpTrait::ConstantLike>()) {
- result = defOp;
- return DiagnosedSilenceableFailure::success();
- }
+ if (defOp->hasTrait<OpTrait::ConstantLike>())
+ return defOp;
values.clear();
@@ -1328,22 +1319,17 @@ DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
if (auto findReplacementOpInterface =
dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
values.assign(findReplacementOpInterface.getNextOperands());
- diag.attachNote(defOp->getLoc()) << "using operands provided by "
- "'FindPayloadReplacementOpInterface'";
continue;
}
// Skip through ops that implement CastOpInterface.
if (isa<CastOpInterface>(defOp)) {
values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
- diag.attachNote(defOp->getLoc())
- << "using output of 'CastOpInterface' op";
continue;
}
} while (!values.empty());
- diag.attachNote() << "ran out of suitable replacement values";
- return diag;
+ return failure();
}
LogicalResult transform::TrackingListener::notifyMatchFailure(
@@ -1412,39 +1398,32 @@ void transform::TrackingListener::notifyOperationReplaced(
};
// Helper function to check if the handle is alive.
- auto firstAliveUser = [&]() -> std::optional<OpOperand *> {
+ auto hasAliveUser = [&]() {
for (Value v : opHandles) {
- for (OpOperand &use : v.getUses())
- if (use.getOwner() != transformOp &&
- !happensBefore(use.getOwner(), transformOp))
- return &use;
+ for (Operation *user : v.getUsers())
+ if (user != transformOp && !happensBefore(user, transformOp))
+ return true;
}
- return std::nullopt;
- }();
+ return false;
+ };
- if (!firstAliveUser.has_value() || handleWasConsumed()) {
+ if (!hasAliveUser() || handleWasConsumed()) {
// The op is tracked but the corresponding handles are dead or were
// consumed. Drop the op form the mapping.
(void)replacePayloadOp(op, nullptr);
return;
}
- Operation *replacement;
- DiagnosedSilenceableFailure diag =
- findReplacementOp(replacement, op, newValues);
+ FailureOr<Operation *> replacement = findReplacementOp(op, newValues);
// If the op is tracked but no replacement op was found, send a
// notification.
- if (!diag.succeeded()) {
- diag.attachNote((*firstAliveUser)->getOwner()->getLoc())
- << "replacement is required because alive handle(s) exist "
- << "(first use in this op as operand number "
- << (*firstAliveUser)->getOperandNumber() << ")";
- notifyPayloadReplacementNotFound(op, newValues, std::move(diag));
+ if (failed(replacement)) {
+ notifyPayloadReplacementNotFound(op, newValues);
(void)replacePayloadOp(op, nullptr);
return;
}
- (void)replacePayloadOp(op, replacement);
+ (void)replacePayloadOp(op, *replacement);
}
transform::ErrorCheckingTrackingListener::~ErrorCheckingTrackingListener() {
@@ -1467,20 +1446,17 @@ bool transform::ErrorCheckingTrackingListener::failed() const {
}
void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound(
- Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) {
-
- // Merge potentially existing diags and store the result in the listener.
- SmallVector<Diagnostic> diags;
- diag.takeDiagnostics(diags);
- if (!status.succeeded())
- status.takeDiagnostics(diags);
- status = DiagnosedSilenceableFailure::silenceableFailure(std::move(diags));
+ Operation *op, ValueRange values) {
+ if (status.succeeded()) {
+ status = emitSilenceableFailure(
+ getTransformOp(), "tracking listener failed to find replacement op");
+ }
- // Report more details.
status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op";
for (auto &&[index, value] : llvm::enumerate(values))
status.attachNote(value.getLoc())
<< "[" << errorCounter << "] replacement value " << index;
+
++errorCounter;
}
diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir
index 2d57d4aa2547f2f..efbdab78d397faa 100644
--- a/mlir/test/Dialect/Transform/test-pattern-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir
@@ -37,14 +37,12 @@ transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // expected-error @below {{tracking listener failed to find replacement op during application of this transform op}}
- // expected-note @below {{ran out of suitable replacement values}}
+ // expected-error @below {{tracking listener failed to find replacement op}}
transform.apply_patterns to %0 {
transform.apply_patterns.transform.test_patterns
} : !transform.any_op
// %1 must be used in some way. If no replacement payload op could be found,
// an error is thrown only if the handle is not dead.
- // expected-note @below {{replacement is required because alive handle(s) exist (first use in this op as operand number 0)}}
transform.annotate %1 "annotated" : !transform.any_op
}
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index 366956c535cfa96..3c510c18996b0c2 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -32,8 +32,8 @@ struct TestTensorTransforms
TestTensorTransforms(const TestTensorTransforms &pass) : PassWrapper(pass) {}
void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<arith::ArithDialect, scf::SCFDialect, linalg::LinalgDialect,
- transform::TransformDialect>();
+ registry
+ .insert<arith::ArithDialect, scf::SCFDialect, linalg::LinalgDialect>();
}
StringRef getArgument() const final {
@@ -292,10 +292,10 @@ class DummyTrackingListener : public transform::TrackingListener {
// Expose `findReplacementOp` as a public function, so that it can be tested.
Operation *getReplacementOp(Operation *op, ValueRange newValues) const {
- Operation *replacementOp;
- if (!findReplacementOp(replacementOp, op, newValues).succeeded())
+ FailureOr<Operation *> replacementOp = findReplacementOp(op, newValues);
+ if (failed(replacementOp))
return nullptr;
- return replacementOp;
+ return *replacementOp;
}
};
} // namespace
@@ -352,17 +352,8 @@ static LogicalResult testTrackingListenerReplacements(Operation *rootOp) {
transform::TransformState transformState =
transform::detail::makeTransformStateForTesting(/*region=*/nullptr,
/*payloadRoot=*/nullptr);
- MLIRContext *context = rootOp->getContext();
- OpBuilder builder(context);
- auto transformOp = builder.create<transform::NamedSequenceOp>(
- rootOp->getLoc(),
- /*sym_name=*/"test_sequence",
- /*function_type=*/
- TypeAttr::get(FunctionType::get(context, TypeRange{}, TypeRange{})),
- /*sym_visibility*/ StringAttr::get(context, "public"),
- /*arg_attrs=*/ArrayAttr::get(context, ArrayRef<Attribute>()),
- /*res_attrs=*/ArrayAttr::get(context, ArrayRef<Attribute>()));
- DummyTrackingListener listener(transformState, transformOp);
+ DummyTrackingListener listener(transformState,
+ transform::TransformOpInterface());
Operation *replacement = listener.getReplacementOp(replaced, replacements);
if (!replacement) {
replaced->emitError("listener could not find replacement op");
More information about the Mlir-commits
mailing list