[Mlir-commits] [mlir] f07718b - [mlir][transform] Improve error when merging of modules fails. (#69331)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Oct 24 07:39:57 PDT 2023
Author: Ingo Müller
Date: 2023-10-24T16:39:52+02:00
New Revision: f07718b708e95c1ba0608eab458e9b37eaa8bbe3
URL: https://github.com/llvm/llvm-project/commit/f07718b708e95c1ba0608eab458e9b37eaa8bbe3
DIFF: https://github.com/llvm/llvm-project/commit/f07718b708e95c1ba0608eab458e9b37eaa8bbe3.diff
LOG: [mlir][transform] Improve error when merging of modules fails. (#69331)
This resolved #69112.
Added:
Modified:
mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
index 3828021d9543035..0cf87c8b3c21cba 100644
--- a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
@@ -81,8 +81,8 @@ TransformOpInterface findTransformEntryPoint(
// function to clone (or move) `other` in order to improve efficiency.
// This might primarily make sense if we can also prune the symbols that
// are merged to a subset (such as those that are actually used).
-LogicalResult mergeSymbolsInto(Operation *target,
- OwningOpRef<Operation *> other);
+InFlightDiagnostic mergeSymbolsInto(Operation *target,
+ OwningOpRef<Operation *> other);
} // namespace detail
/// Standalone util to apply the named sequence `transformRoot` to `payload` IR.
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index 741456e7ebbfb86..e14461eaf124abc 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -337,11 +337,14 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
diag.attachNote(target->getLoc()) << "pass anchor op";
return diag;
}
- if (failed(detail::mergeSymbolsInto(
- SymbolTable::getNearestSymbolTable(transformRoot),
- transformLibraryModule->get()->clone())))
- return emitError(transformRoot->getLoc(),
- "failed to merge library symbols into transform root");
+ InFlightDiagnostic diag = detail::mergeSymbolsInto(
+ SymbolTable::getNearestSymbolTable(transformRoot),
+ transformLibraryModule->get()->clone());
+ if (failed(diag)) {
+ diag.attachNote(transformRoot->getLoc())
+ << "failed to merge library symbols into transform root";
+ return diag;
+ }
}
// Step 4
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
index 2e32b7f71d10448..b6b6256223ed23d 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
@@ -177,8 +177,8 @@ LogicalResult transform::detail::assembleTransformLibraryFromPaths(
for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) {
if (failed(transform::detail::mergeSymbolsInto(
mergedParsedLibraries.get(), std::move(parsedLibrary))))
- return mergedParsedLibraries->emitError()
- << "failed to verify merged transform module";
+ return parsedLibrary->emitError()
+ << "failed to merge symbols into shared library module";
}
}
@@ -197,8 +197,8 @@ static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
/// Merge `func1` into `func2`. The two ops must be inside the same parent op
/// and mergable according to `canMergeInto`. The function erases `func1` such
/// that only `func2` exists when the function returns.
-static LogicalResult mergeInto(FunctionOpInterface func1,
- FunctionOpInterface func2) {
+static InFlightDiagnostic mergeInto(FunctionOpInterface func1,
+ FunctionOpInterface func2) {
assert(canMergeInto(func1, func2));
assert(func1->getParentOp() == func2->getParentOp() &&
"expected func1 and func2 to be in the same parent op");
@@ -241,10 +241,10 @@ static LogicalResult mergeInto(FunctionOpInterface func1,
assert(func1.isExternal());
func1->erase();
- return success();
+ return InFlightDiagnostic();
}
-LogicalResult
+InFlightDiagnostic
transform::detail::mergeSymbolsInto(Operation *target,
OwningOpRef<Operation *> other) {
assert(target->hasTrait<OpTrait::SymbolTable>() &&
@@ -301,7 +301,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
auto renameToUnique =
[&](SymbolOpInterface op, SymbolOpInterface otherOp,
SymbolTable &symbolTable,
- SymbolTable &otherSymbolTable) -> LogicalResult {
+ SymbolTable &otherSymbolTable) -> InFlightDiagnostic {
LLVM_DEBUG(llvm::dbgs() << ", renaming\n");
FailureOr<StringAttr> maybeNewName =
symbolTable.renameToUnique(op, {&otherSymbolTable});
@@ -313,19 +313,21 @@ transform::detail::mergeSymbolsInto(Operation *target,
}
LLVM_DEBUG(DBGS() << " renamed to @" << maybeNewName->getValue()
<< "\n");
- return success();
+ return InFlightDiagnostic();
};
if (symbolOp.isPrivate()) {
- if (failed(renameToUnique(symbolOp, collidingOp, *symbolTable,
- *otherSymbolTable)))
- return failure();
+ InFlightDiagnostic diag = renameToUnique(
+ symbolOp, collidingOp, *symbolTable, *otherSymbolTable);
+ if (failed(diag))
+ return diag;
continue;
}
if (collidingOp.isPrivate()) {
- if (failed(renameToUnique(collidingOp, symbolOp, *otherSymbolTable,
- *symbolTable)))
- return failure();
+ InFlightDiagnostic diag = renameToUnique(
+ collidingOp, symbolOp, *otherSymbolTable, *symbolTable);
+ if (failed(diag))
+ return diag;
continue;
}
LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");
@@ -394,8 +396,10 @@ transform::detail::mergeSymbolsInto(Operation *target,
assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);
// Do the actual merging.
- if (failed(mergeInto(funcOp, collidingFuncOp))) {
- return failure();
+ {
+ InFlightDiagnostic diag = mergeInto(funcOp, collidingFuncOp);
+ if (failed(diag))
+ return diag;
}
}
}
@@ -405,7 +409,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
<< "failed to verify target op after merging symbols";
LLVM_DEBUG(DBGS() << "done merging ops\n");
- return success();
+ return InFlightDiagnostic();
}
LogicalResult transform::applyTransformNamedSequence(
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
index 060dab334ed4388..f2d143bfb770ab9 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
@@ -8,7 +8,7 @@ module attributes {transform.with_named_sequence} {
// expected-error @below {{external definition has a mismatching signature}}
transform.named_sequence private @print_message(!transform.op<"builtin.module"> {transform.readonly})
- // expected-error @below {{failed to merge library symbols into transform root}}
+ // expected-note @below {{failed to merge library symbols into transform root}}
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.op<"builtin.module">):
include @print_message failures(propagate) (%arg0) : (!transform.op<"builtin.module">) -> ()
@@ -33,7 +33,7 @@ module attributes {transform.with_named_sequence} {
// expected-error @below {{external definition has mismatching consumption annotations for argument #0}}
transform.named_sequence private @consuming(%arg0: !transform.any_op {transform.readonly})
- // expected-error @below {{failed to merge library symbols into transform root}}
+ // expected-note @below {{failed to merge library symbols into transform root}}
transform.sequence failures(suppress) {
^bb0(%arg0: !transform.any_op):
include @consuming failures(suppress) (%arg0) : (!transform.any_op) -> ()
@@ -49,7 +49,7 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
- // expected-error @below {{failed to merge library symbols into transform root}}
+ // expected-note @below {{failed to merge library symbols into transform root}}
transform.sequence failures(suppress) {
^bb0(%arg0: !transform.any_op):
include @print_message failures(propagate) (%arg0) : (!transform.any_op) -> ()
More information about the Mlir-commits mailing list