[Mlir-commits] [mlir] [mlir][transform] Improve error when merging of modules fails. (PR #69331)

Ingo Müller llvmlistbot at llvm.org
Fri Oct 20 06:43:11 PDT 2023


https://github.com/ingomueller-net updated https://github.com/llvm/llvm-project/pull/69331

>From a3b77a4914bbd0681db123ad64c57cbcc17c4d51 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Tue, 17 Oct 2023 13:36:09 +0000
Subject: [PATCH 1/2] [mlir][transform] Improve error when merging of modules
 fails.

This resolves #69112.
---
 .../Transform/Transforms/TransformInterpreterUtils.cpp        | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
index 41feffffaf97b3f..e0bde020daf7c6f 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
@@ -176,8 +176,8 @@ LogicalResult transform::detail::assembleTransformLibraryFromPaths(
     for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) {
       if (failed(transform::detail::mergeSymbolsInto(
               mergedParsedLibraries.get(), std::move(parsedLibrary))))
-        return mergedParsedLibraries->emitError()
-               << "failed to verify merged transform module";
+        return parsedLibrary->emitError()
+               << "failed to merge symbols into shared library module";
     }
   }
 

>From 3a4b1d17011077ba430998e5b368039bde2ed501 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Fri, 20 Oct 2023 13:42:56 +0000
Subject: [PATCH 2/2] Rework involved functions to return 'InFlightDiagnostic'.

---
 .../Transforms/TransformInterpreterUtils.h    |  4 +--
 .../TransformInterpreterPassBase.cpp          | 13 ++++---
 .../Transforms/TransformInterpreterUtils.cpp  | 34 +++++++++++--------
 ...erpreter-external-symbol-decl-invalid.mlir |  6 ++--
 4 files changed, 32 insertions(+), 25 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
index 3fc02267f26e9da..4b64df426698744 100644
--- a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
@@ -81,8 +81,8 @@ TransformOpInterface findTransformEntryPoint(
 //       function to clone (or move) `other` in order to improve efficiency.
 //       This might primarily make sense if we can also prune the symbols that
 //       are merged to a subset (such as those that are actually used).
-LogicalResult mergeSymbolsInto(Operation *target,
-                               OwningOpRef<Operation *> other);
+InFlightDiagnostic mergeSymbolsInto(Operation *target,
+                                    OwningOpRef<Operation *> other);
 } // namespace detail
 
 /// Standalone util to apply the named sequence `entryPoint` to the payload.
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index 538c81fe39fddb2..ca06e06bb299d45 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -335,11 +335,14 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
       diag.attachNote(target->getLoc()) << "pass anchor op";
       return diag;
     }
-    if (failed(detail::mergeSymbolsInto(
-            SymbolTable::getNearestSymbolTable(transformRoot),
-            transformLibraryModule->get()->clone())))
-      return emitError(transformRoot->getLoc(),
-                       "failed to merge library symbols into transform root");
+    InFlightDiagnostic diag = detail::mergeSymbolsInto(
+        SymbolTable::getNearestSymbolTable(transformRoot),
+        transformLibraryModule->get()->clone());
+    if (failed(diag)) {
+      diag.attachNote(transformRoot->getLoc())
+          << "failed to merge library symbols into transform root";
+      return diag;
+    }
   }
 
   // Step 4
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
index e0bde020daf7c6f..64560a4f3abb8bd 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
@@ -196,8 +196,8 @@ static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
 /// Merge `func1` into `func2`. The two ops must be inside the same parent op
 /// and mergable according to `canMergeInto`. The function erases `func1` such
 /// that only `func2` exists when the function returns.
-static LogicalResult mergeInto(FunctionOpInterface func1,
-                               FunctionOpInterface func2) {
+static InFlightDiagnostic mergeInto(FunctionOpInterface func1,
+                                    FunctionOpInterface func2) {
   assert(canMergeInto(func1, func2));
   assert(func1->getParentOp() == func2->getParentOp() &&
          "expected func1 and func2 to be in the same parent op");
@@ -240,10 +240,10 @@ static LogicalResult mergeInto(FunctionOpInterface func1,
   assert(func1.isExternal());
   func1->erase();
 
-  return success();
+  return InFlightDiagnostic();
 }
 
-LogicalResult
+InFlightDiagnostic
 transform::detail::mergeSymbolsInto(Operation *target,
                                     OwningOpRef<Operation *> other) {
   assert(target->hasTrait<OpTrait::SymbolTable>() &&
@@ -300,7 +300,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
       auto renameToUnique =
           [&](SymbolOpInterface op, SymbolOpInterface otherOp,
               SymbolTable &symbolTable,
-              SymbolTable &otherSymbolTable) -> LogicalResult {
+              SymbolTable &otherSymbolTable) -> InFlightDiagnostic {
         LLVM_DEBUG(llvm::dbgs() << ", renaming\n");
         FailureOr<StringAttr> maybeNewName =
             symbolTable.renameToUnique(op, {&otherSymbolTable});
@@ -312,19 +312,21 @@ transform::detail::mergeSymbolsInto(Operation *target,
         }
         LLVM_DEBUG(DBGS() << "      renamed to @" << maybeNewName->getValue()
                           << "\n");
-        return success();
+        return InFlightDiagnostic();
       };
 
       if (symbolOp.isPrivate()) {
-        if (failed(renameToUnique(symbolOp, collidingOp, *symbolTable,
-                                  *otherSymbolTable)))
-          return failure();
+        InFlightDiagnostic diag = renameToUnique(
+            symbolOp, collidingOp, *symbolTable, *otherSymbolTable);
+        if (failed(diag))
+          return diag;
         continue;
       }
       if (collidingOp.isPrivate()) {
-        if (failed(renameToUnique(collidingOp, symbolOp, *otherSymbolTable,
-                                  *symbolTable)))
-          return failure();
+        InFlightDiagnostic diag = renameToUnique(
+            collidingOp, symbolOp, *otherSymbolTable, *symbolTable);
+        if (failed(diag))
+          return diag;
         continue;
       }
       LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");
@@ -393,8 +395,10 @@ transform::detail::mergeSymbolsInto(Operation *target,
       assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);
 
       // Do the actual merging.
-      if (failed(mergeInto(funcOp, collidingFuncOp))) {
-        return failure();
+      {
+        InFlightDiagnostic diag = mergeInto(funcOp, collidingFuncOp);
+        if (failed(diag))
+          return diag;
       }
     }
   }
@@ -404,7 +408,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
            << "failed to verify target op after merging symbols";
 
   LLVM_DEBUG(DBGS() << "done merging ops\n");
-  return success();
+  return InFlightDiagnostic();
 }
 
 LogicalResult transform::applyTransformNamedSequence(
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
index c1bd071dc138d56..7f61f5afd9c9857 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
@@ -8,7 +8,7 @@ module attributes {transform.with_named_sequence} {
   // expected-error @below {{external definition has a mismatching signature}}
   transform.named_sequence private @print_message(!transform.op<"builtin.module"> {transform.readonly})
 
-  // expected-error @below {{failed to merge library symbols into transform root}}
+  // expected-note @below {{failed to merge library symbols into transform root}}
   transform.sequence failures(propagate) {
   ^bb0(%arg0: !transform.op<"builtin.module">):
     include @print_message failures(propagate) (%arg0) : (!transform.op<"builtin.module">) -> ()
@@ -33,7 +33,7 @@ module attributes {transform.with_named_sequence} {
   // expected-error @below {{external definition has mismatching consumption annotations for argument #0}}
   transform.named_sequence private @consuming(%arg0: !transform.any_op {transform.readonly})
 
-  // expected-error @below {{failed to merge library symbols into transform root}}
+  // expected-note @below {{failed to merge library symbols into transform root}}
   transform.sequence failures(suppress) {
   ^bb0(%arg0: !transform.any_op):
     include @consuming failures(suppress) (%arg0) : (!transform.any_op) -> ()
@@ -49,7 +49,7 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 
-  // expected-error @below {{failed to merge library symbols into transform root}}
+  // expected-note @below {{failed to merge library symbols into transform root}}
   transform.sequence failures(suppress) {
   ^bb0(%arg0: !transform.any_op):
     include @print_message failures(propagate) (%arg0) : (!transform.any_op) -> ()



More information about the Mlir-commits mailing list