[Mlir-commits] [mlir] [mlir][transform] Improve error when merging of modules fails. (PR #69331)
Ingo Müller
llvmlistbot at llvm.org
Fri Oct 20 06:43:11 PDT 2023
https://github.com/ingomueller-net updated https://github.com/llvm/llvm-project/pull/69331
>From a3b77a4914bbd0681db123ad64c57cbcc17c4d51 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Tue, 17 Oct 2023 13:36:09 +0000
Subject: [PATCH 1/2] [mlir][transform] Improve error when merging of modules
fails.
This resolves #69112.
---
.../Transform/Transforms/TransformInterpreterUtils.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
index 41feffffaf97b3f..e0bde020daf7c6f 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
@@ -176,8 +176,8 @@ LogicalResult transform::detail::assembleTransformLibraryFromPaths(
for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) {
if (failed(transform::detail::mergeSymbolsInto(
mergedParsedLibraries.get(), std::move(parsedLibrary))))
- return mergedParsedLibraries->emitError()
- << "failed to verify merged transform module";
+ return mergedParsedLibraries->emitError() // not parsedLibrary: moved from above
+ << "failed to merge symbols into shared library module";
}
}
>From 3a4b1d17011077ba430998e5b368039bde2ed501 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Fri, 20 Oct 2023 13:42:56 +0000
Subject: [PATCH 2/2] Rework involved functions to return 'InFlightDiagnostic'.
---
.../Transforms/TransformInterpreterUtils.h | 4 +--
.../TransformInterpreterPassBase.cpp | 13 ++++---
.../Transforms/TransformInterpreterUtils.cpp | 34 +++++++++++--------
...erpreter-external-symbol-decl-invalid.mlir | 6 ++--
4 files changed, 32 insertions(+), 25 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
index 3fc02267f26e9da..4b64df426698744 100644
--- a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
@@ -81,8 +81,8 @@ TransformOpInterface findTransformEntryPoint(
// function to clone (or move) `other` in order to improve efficiency.
// This might primarily make sense if we can also prune the symbols that
// are merged to a subset (such as those that are actually used).
-LogicalResult mergeSymbolsInto(Operation *target,
- OwningOpRef<Operation *> other);
+InFlightDiagnostic mergeSymbolsInto(Operation *target,
+ OwningOpRef<Operation *> other);
} // namespace detail
/// Standalone util to apply the named sequence `entryPoint` to the payload.
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index 538c81fe39fddb2..ca06e06bb299d45 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -335,11 +335,14 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
diag.attachNote(target->getLoc()) << "pass anchor op";
return diag;
}
- if (failed(detail::mergeSymbolsInto(
- SymbolTable::getNearestSymbolTable(transformRoot),
- transformLibraryModule->get()->clone())))
- return emitError(transformRoot->getLoc(),
- "failed to merge library symbols into transform root");
+ InFlightDiagnostic diag = detail::mergeSymbolsInto(
+ SymbolTable::getNearestSymbolTable(transformRoot),
+ transformLibraryModule->get()->clone());
+ if (failed(diag)) {
+ diag.attachNote(transformRoot->getLoc())
+ << "failed to merge library symbols into transform root";
+ return diag;
+ }
}
// Step 4
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
index e0bde020daf7c6f..64560a4f3abb8bd 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
@@ -196,8 +196,8 @@ static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
/// Merge `func1` into `func2`. The two ops must be inside the same parent op
/// and mergable according to `canMergeInto`. The function erases `func1` such
/// that only `func2` exists when the function returns.
-static LogicalResult mergeInto(FunctionOpInterface func1,
- FunctionOpInterface func2) {
+static InFlightDiagnostic mergeInto(FunctionOpInterface func1,
+ FunctionOpInterface func2) {
assert(canMergeInto(func1, func2));
assert(func1->getParentOp() == func2->getParentOp() &&
"expected func1 and func2 to be in the same parent op");
@@ -240,10 +240,10 @@ static LogicalResult mergeInto(FunctionOpInterface func1,
assert(func1.isExternal());
func1->erase();
- return success();
+ return InFlightDiagnostic();
}
-LogicalResult
+InFlightDiagnostic
transform::detail::mergeSymbolsInto(Operation *target,
OwningOpRef<Operation *> other) {
assert(target->hasTrait<OpTrait::SymbolTable>() &&
@@ -300,7 +300,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
auto renameToUnique =
[&](SymbolOpInterface op, SymbolOpInterface otherOp,
SymbolTable &symbolTable,
- SymbolTable &otherSymbolTable) -> LogicalResult {
+ SymbolTable &otherSymbolTable) -> InFlightDiagnostic {
LLVM_DEBUG(llvm::dbgs() << ", renaming\n");
FailureOr<StringAttr> maybeNewName =
symbolTable.renameToUnique(op, {&otherSymbolTable});
@@ -312,19 +312,21 @@ transform::detail::mergeSymbolsInto(Operation *target,
}
LLVM_DEBUG(DBGS() << " renamed to @" << maybeNewName->getValue()
<< "\n");
- return success();
+ return InFlightDiagnostic();
};
if (symbolOp.isPrivate()) {
- if (failed(renameToUnique(symbolOp, collidingOp, *symbolTable,
- *otherSymbolTable)))
- return failure();
+ InFlightDiagnostic diag = renameToUnique(
+ symbolOp, collidingOp, *symbolTable, *otherSymbolTable);
+ if (failed(diag))
+ return diag;
continue;
}
if (collidingOp.isPrivate()) {
- if (failed(renameToUnique(collidingOp, symbolOp, *otherSymbolTable,
- *symbolTable)))
- return failure();
+ InFlightDiagnostic diag = renameToUnique(
+ collidingOp, symbolOp, *otherSymbolTable, *symbolTable);
+ if (failed(diag))
+ return diag;
continue;
}
LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");
@@ -393,8 +395,10 @@ transform::detail::mergeSymbolsInto(Operation *target,
assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);
// Do the actual merging.
- if (failed(mergeInto(funcOp, collidingFuncOp))) {
- return failure();
+ {
+ InFlightDiagnostic diag = mergeInto(funcOp, collidingFuncOp);
+ if (failed(diag))
+ return diag;
}
}
}
@@ -404,7 +408,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
<< "failed to verify target op after merging symbols";
LLVM_DEBUG(DBGS() << "done merging ops\n");
- return success();
+ return InFlightDiagnostic();
}
LogicalResult transform::applyTransformNamedSequence(
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
index c1bd071dc138d56..7f61f5afd9c9857 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
@@ -8,7 +8,7 @@ module attributes {transform.with_named_sequence} {
// expected-error @below {{external definition has a mismatching signature}}
transform.named_sequence private @print_message(!transform.op<"builtin.module"> {transform.readonly})
- // expected-error @below {{failed to merge library symbols into transform root}}
+ // expected-note @below {{failed to merge library symbols into transform root}}
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.op<"builtin.module">):
include @print_message failures(propagate) (%arg0) : (!transform.op<"builtin.module">) -> ()
@@ -33,7 +33,7 @@ module attributes {transform.with_named_sequence} {
// expected-error @below {{external definition has mismatching consumption annotations for argument #0}}
transform.named_sequence private @consuming(%arg0: !transform.any_op {transform.readonly})
- // expected-error @below {{failed to merge library symbols into transform root}}
+ // expected-note @below {{failed to merge library symbols into transform root}}
transform.sequence failures(suppress) {
^bb0(%arg0: !transform.any_op):
include @consuming failures(suppress) (%arg0) : (!transform.any_op) -> ()
@@ -49,7 +49,7 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
- // expected-error @below {{failed to merge library symbols into transform root}}
+ // expected-note @below {{failed to merge library symbols into transform root}}
transform.sequence failures(suppress) {
^bb0(%arg0: !transform.any_op):
include @print_message failures(propagate) (%arg0) : (!transform.any_op) -> ()
More information about the Mlir-commits
mailing list