[Mlir-commits] [mlir] b0bf7ff - [mlir] add utilities for DiagnosedSilenceableFailure
Alex Zinenko
llvmlistbot at llvm.org
Mon Oct 17 08:31:58 PDT 2022
Author: Alex Zinenko
Date: 2022-10-17T15:31:28Z
New Revision: b0bf7ffffc3a2a65a22da4afb13feb855baf2042
URL: https://github.com/llvm/llvm-project/commit/b0bf7ffffc3a2a65a22da4afb13feb855baf2042
DIFF: https://github.com/llvm/llvm-project/commit/b0bf7ffffc3a2a65a22da4afb13feb855baf2042.diff
LOG: [mlir] add utilities for DiagnosedSilenceableFailure
This commit adds helper functions similar to `emitError` for the
DiagnosedSilenceableFailure class in both the silenceable and definite
failure cases. These helpers simplify the use of said class and make
transform op application code idiomatic.
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D136072
Added:
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Transform/transform-state-extension.mlir
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 2c81985e3c127..b56a5dd36e77a 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -181,6 +181,99 @@ class [[nodiscard]] DiagnosedSilenceableFailure {
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
};
+class DiagnosedDefiniteFailure;
+
+DiagnosedDefiniteFailure emitDefiniteFailure(Location loc,
+ const Twine &message = {});
+
+/// A compatibility class connecting `InFlightDiagnostic` to
+/// `DiagnosedSilenceableFailure` while providing an interface similar to the
+/// former. Implicitly convertible to `DiagnosedSilenceableFailure` in definite
+/// failure state and to `LogicalResult` failure. Reports the error on
+/// conversion or on destruction. Instances of this class can be created by
+/// `emitDefiniteFailure()`.
+class DiagnosedDefiniteFailure {
+ // The only way to construct an instance: the factory needs access to the
+ // private constructor below.
+ friend DiagnosedDefiniteFailure emitDefiniteFailure(Location loc,
+ const Twine &message);
+
+public:
+ /// Only move-constructible because it carries an in-flight diagnostic.
+ DiagnosedDefiniteFailure(DiagnosedDefiniteFailure &&) = default;
+
+ /// Forward the message to the diagnostic.
+ template <typename T>
+ DiagnosedDefiniteFailure &operator<<(T &&value) & {
+ diag << std::forward<T>(value);
+ return *this;
+ }
+ /// Rvalue overload: lets `emitDefiniteFailure() << ...` be chained on the
+ /// temporary and returned (and thereby converted) directly.
+ template <typename T>
+ DiagnosedDefiniteFailure &&operator<<(T &&value) && {
+ return std::move(this->operator<<(std::forward<T>(value)));
+ }
+
+ /// Attaches a note to the error.
+ Diagnostic &attachNote(Optional<Location> loc = llvm::None) {
+ return diag.attachNote(loc);
+ }
+
+ /// Implicit conversion to DiagnosedSilenceableFailure in the definite failure
+ /// state. Reports the error.
+ operator DiagnosedSilenceableFailure() {
+ diag.report();
+ return DiagnosedSilenceableFailure::definiteFailure();
+ }
+
+ /// Implicit conversion to LogicalResult in the failure state. Reports the
+ /// error.
+ operator LogicalResult() {
+ diag.report();
+ return failure();
+ }
+
+private:
+ /// Constructs a definite failure at the given location with the given
+ /// message.
+ explicit DiagnosedDefiniteFailure(Location loc, const Twine &message)
+ : diag(emitError(loc, message)) {}
+
+ /// Copy-construction and any assignment are disallowed to prevent repeated
+ /// error reporting.
+ DiagnosedDefiniteFailure(const DiagnosedDefiniteFailure &) = delete;
+ DiagnosedDefiniteFailure &
+ operator=(const DiagnosedDefiniteFailure &) = delete;
+ DiagnosedDefiniteFailure &operator=(DiagnosedDefiniteFailure &&) = delete;
+
+ /// The in-flight error diagnostic being built up.
+ InFlightDiagnostic diag;
+};
+
+/// Emits a definite failure with the given message. The returned object allows
+/// for last-minute modification to the error message, such as attaching notes
+/// and completing the message. It will be reported when the object is
+/// destructed or converted. The default (empty) `message` comes from the
+/// forward declaration of this function earlier in the header.
+inline DiagnosedDefiniteFailure emitDefiniteFailure(Location loc,
+ const Twine &message) {
+ return DiagnosedDefiniteFailure(loc, message);
+}
+/// Convenience overload that emits the definite failure at `op`'s location.
+inline DiagnosedDefiniteFailure emitDefiniteFailure(Operation *op,
+ const Twine &message = {}) {
+ return emitDefiniteFailure(op->getLoc(), message);
+}
+
+/// Emits a silenceable failure with the given message. A silenceable failure
+/// must be either suppressed or converted into a definite failure and reported
+/// to the user.
+inline DiagnosedSilenceableFailure
+emitSilenceableFailure(Location loc, const Twine &message = {}) {
+ // Build an Error-severity diagnostic; it is carried, not reported, here.
+ Diagnostic diag(loc, DiagnosticSeverity::Error);
+ diag << message;
+ return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+}
+/// Convenience overload that emits the silenceable failure at `op`'s location.
+inline DiagnosedSilenceableFailure
+emitSilenceableFailure(Operation *op, const Twine &message = {}) {
+ return emitSilenceableFailure(op->getLoc(), message);
+}
+
namespace transform {
class TransformOpInterface;
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
index 238825d1ff0c9..fe29f303a630a 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -63,20 +63,29 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> {
}
/// Creates the silenceable failure object with a diagnostic located at the
- /// current operation.
- DiagnosedSilenceableFailure emitSilenceableError() {
- Diagnostic diag($_op->getLoc(), DiagnosticSeverity::Error);
- return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+ /// current operation, seeded with `message`. Silenceable failure must be
+ /// suppressed or reported explicitly at some later time.
+ DiagnosedSilenceableFailure
+ emitSilenceableError(const ::llvm::Twine &message = {}) {
+ return ::mlir::emitSilenceableFailure($_op, message);
+ }
+
+ /// Creates the definite failure object with a diagnostic located at the
+ /// current operation, seeded with `message`. Definite failure will be
+ /// reported when the object is destroyed or converted.
+ DiagnosedDefiniteFailure
+ emitDefiniteFailure(const ::llvm::Twine &message = {}) {
+ // Must be qualified: this member's own name would otherwise hide the
+ // free function of the same name.
+ return ::mlir::emitDefiniteFailure($_op, message);
}
/// Creates the default silenceable failure for a transform op that failed
/// to properly apply to a target.
DiagnosedSilenceableFailure emitDefaultSilenceableFailure(
Operation *target) {
- Diagnostic diag($_op->getLoc(), DiagnosticSeverity::Error);
+ // The error is located at the transform op; the note points at `target`.
+ DiagnosedSilenceableFailure diag = emitSilenceableFailure($_op->getLoc());
diag << $_op->getName() << " failed to apply";
diag.attachNote(target->getLoc()) << "when applied to this op";
- return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+ return diag;
}
}];
}
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index 92b0b5eb1042f..280dbf0c87d9b 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -324,8 +324,7 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
if (transformOp.has_value()) {
return transformOp->emitSilenceableError() << message;
}
- foreachThreadOp->emitError() << message;
- return DiagnosedSilenceableFailure::definiteFailure();
+ return emitDefiniteFailure(foreachThreadOp, message);
};
if (foreachThreadOp.getNumResults() > 0)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index e47e8e51c6830..1030f0a9c00a3 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -470,10 +470,9 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
}
ArrayRef<Operation *> containingOps = state.getPayloadOps(getContainingOp());
if (containingOps.size() != 1) {
- // Definite failure.
- return DiagnosedSilenceableFailure(
- this->emitOpError("requires exactly one containing_op handle (got ")
- << containingOps.size() << ")");
+ return emitDefiniteFailure()
+ << "requires exactly one containing_op handle (got "
+ << containingOps.size() << ")";
}
Operation *containingOp = containingOps.front();
@@ -925,11 +924,11 @@ DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
}
if (splitPoints.size() != payload.size()) {
- emitError() << "expected the dynamic split point handle to point to as "
- "many operations ("
- << splitPoints.size() << ") as the target handle ("
- << payload.size() << ")";
- return DiagnosedSilenceableFailure::definiteFailure();
+ return emitDefiniteFailure()
+ << "expected the dynamic split point handle to point to as "
+ "many operations ("
+ << splitPoints.size() << ") as the target handle ("
+ << payload.size() << ")";
}
} else {
splitPoints.resize(payload.size(),
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 6d32c40df8a1f..e632be0d004b7 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -177,17 +177,16 @@ transform::AlternativesOp::apply(transform::TransformResults &results,
for (Operation *original : originals) {
if (original->isAncestor(getOperation())) {
- InFlightDiagnostic diag =
- emitError() << "scope must not contain the transforms being applied";
+ auto diag = emitDefiniteFailure()
+ << "scope must not contain the transforms being applied";
diag.attachNote(original->getLoc()) << "scope";
- return DiagnosedSilenceableFailure::definiteFailure();
+ return diag;
}
if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
- InFlightDiagnostic diag =
- emitError()
- << "only isolated-from-above ops can be alternative scopes";
+ auto diag = emitDefiniteFailure()
+ << "only isolated-from-above ops can be alternative scopes";
diag.attachNote(original->getLoc()) << "scope";
- return DiagnosedSilenceableFailure(std::move(diag));
+ return diag;
}
}
@@ -523,8 +522,8 @@ transform::PDLMatchOp::apply(transform::TransformResults &results,
for (Operation *root : state.getPayloadOps(getRoot())) {
if (failed(extension->findAllMatches(
getPatternName().getLeafReference().getValue(), root, targets))) {
- emitOpError() << "could not find pattern '" << getPatternName() << "'";
- return DiagnosedSilenceableFailure::definiteFailure();
+ return emitDefiniteFailure()
+ << "could not find pattern '" << getPatternName() << "'";
}
}
results.set(getResult().cast<OpResult>(), targets);
diff --git a/mlir/test/Dialect/Transform/transform-state-extension.mlir b/mlir/test/Dialect/Transform/transform-state-extension.mlir
index f7f678ef1c2e1..1f29684b35e3a 100644
--- a/mlir/test/Dialect/Transform/transform-state-extension.mlir
+++ b/mlir/test/Dialect/Transform/transform-state-extension.mlir
@@ -44,3 +44,13 @@ module {
test_check_if_test_extension_present %arg0
}
}
+
+// -----
+
+module {
+ transform.sequence failures(suppress) {
+ ^bb0(%arg0: !pdl.operation):
+ // expected-error @below {{TestTransformStateExtension missing}}
+ test_remap_operand_to_self %arg0
+ }
+}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index b890af57f8d00..483d4feda7abc 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -188,10 +188,8 @@ mlir::test::TestCheckIfTestExtensionPresentOp::apply(
DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply(
transform::TransformResults &results, transform::TransformState &state) {
auto *extension = state.getExtension<TestTransformStateExtension>();
- if (!extension) {
- emitError() << "TestTransformStateExtension missing";
- return DiagnosedSilenceableFailure::definiteFailure();
- }
+ if (!extension)
+ return emitDefiniteFailure("TestTransformStateExtension missing");
if (failed(extension->updateMapping(state.getPayloadOps(getOperand()).front(),
getOperation())))
More information about the Mlir-commits
mailing list