[Mlir-commits] [mlir] f07718b - [mlir][transform] Improve error when merging of modules fails. (#69331)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 24 07:39:57 PDT 2023


Author: Ingo Müller
Date: 2023-10-24T16:39:52+02:00
New Revision: f07718b708e95c1ba0608eab458e9b37eaa8bbe3

URL: https://github.com/llvm/llvm-project/commit/f07718b708e95c1ba0608eab458e9b37eaa8bbe3
DIFF: https://github.com/llvm/llvm-project/commit/f07718b708e95c1ba0608eab458e9b37eaa8bbe3.diff

LOG: [mlir][transform] Improve error when merging of modules fails. (#69331)

This resolves #69112.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
    mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
    mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
    mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
index 3828021d9543035..0cf87c8b3c21cba 100644
--- a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
@@ -81,8 +81,8 @@ TransformOpInterface findTransformEntryPoint(
 //       function to clone (or move) `other` in order to improve efficiency.
 //       This might primarily make sense if we can also prune the symbols that
 //       are merged to a subset (such as those that are actually used).
-LogicalResult mergeSymbolsInto(Operation *target,
-                               OwningOpRef<Operation *> other);
+InFlightDiagnostic mergeSymbolsInto(Operation *target,
+                                    OwningOpRef<Operation *> other);
 } // namespace detail
 
 /// Standalone util to apply the named sequence `transformRoot` to `payload` IR.

diff  --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index 741456e7ebbfb86..e14461eaf124abc 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -337,11 +337,14 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
       diag.attachNote(target->getLoc()) << "pass anchor op";
       return diag;
     }
-    if (failed(detail::mergeSymbolsInto(
-            SymbolTable::getNearestSymbolTable(transformRoot),
-            transformLibraryModule->get()->clone())))
-      return emitError(transformRoot->getLoc(),
-                       "failed to merge library symbols into transform root");
+    InFlightDiagnostic diag = detail::mergeSymbolsInto(
+        SymbolTable::getNearestSymbolTable(transformRoot),
+        transformLibraryModule->get()->clone());
+    if (failed(diag)) {
+      diag.attachNote(transformRoot->getLoc())
+          << "failed to merge library symbols into transform root";
+      return diag;
+    }
   }
 
   // Step 4

diff  --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
index 2e32b7f71d10448..b6b6256223ed23d 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
@@ -177,8 +177,8 @@ LogicalResult transform::detail::assembleTransformLibraryFromPaths(
     for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) {
       if (failed(transform::detail::mergeSymbolsInto(
               mergedParsedLibraries.get(), std::move(parsedLibrary))))
-        return mergedParsedLibraries->emitError()
-               << "failed to verify merged transform module";
+        return parsedLibrary->emitError()
+               << "failed to merge symbols into shared library module";
     }
   }
 
@@ -197,8 +197,8 @@ static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
 /// Merge `func1` into `func2`. The two ops must be inside the same parent op
 /// and mergable according to `canMergeInto`. The function erases `func1` such
 /// that only `func2` exists when the function returns.
-static LogicalResult mergeInto(FunctionOpInterface func1,
-                               FunctionOpInterface func2) {
+static InFlightDiagnostic mergeInto(FunctionOpInterface func1,
+                                    FunctionOpInterface func2) {
   assert(canMergeInto(func1, func2));
   assert(func1->getParentOp() == func2->getParentOp() &&
          "expected func1 and func2 to be in the same parent op");
@@ -241,10 +241,10 @@ static LogicalResult mergeInto(FunctionOpInterface func1,
   assert(func1.isExternal());
   func1->erase();
 
-  return success();
+  return InFlightDiagnostic();
 }
 
-LogicalResult
+InFlightDiagnostic
 transform::detail::mergeSymbolsInto(Operation *target,
                                     OwningOpRef<Operation *> other) {
   assert(target->hasTrait<OpTrait::SymbolTable>() &&
@@ -301,7 +301,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
       auto renameToUnique =
           [&](SymbolOpInterface op, SymbolOpInterface otherOp,
               SymbolTable &symbolTable,
-              SymbolTable &otherSymbolTable) -> LogicalResult {
+              SymbolTable &otherSymbolTable) -> InFlightDiagnostic {
         LLVM_DEBUG(llvm::dbgs() << ", renaming\n");
         FailureOr<StringAttr> maybeNewName =
             symbolTable.renameToUnique(op, {&otherSymbolTable});
@@ -313,19 +313,21 @@ transform::detail::mergeSymbolsInto(Operation *target,
         }
         LLVM_DEBUG(DBGS() << "      renamed to @" << maybeNewName->getValue()
                           << "\n");
-        return success();
+        return InFlightDiagnostic();
       };
 
       if (symbolOp.isPrivate()) {
-        if (failed(renameToUnique(symbolOp, collidingOp, *symbolTable,
-                                  *otherSymbolTable)))
-          return failure();
+        InFlightDiagnostic diag = renameToUnique(
+            symbolOp, collidingOp, *symbolTable, *otherSymbolTable);
+        if (failed(diag))
+          return diag;
         continue;
       }
       if (collidingOp.isPrivate()) {
-        if (failed(renameToUnique(collidingOp, symbolOp, *otherSymbolTable,
-                                  *symbolTable)))
-          return failure();
+        InFlightDiagnostic diag = renameToUnique(
+            collidingOp, symbolOp, *otherSymbolTable, *symbolTable);
+        if (failed(diag))
+          return diag;
         continue;
       }
       LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");
@@ -394,8 +396,10 @@ transform::detail::mergeSymbolsInto(Operation *target,
       assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);
 
       // Do the actual merging.
-      if (failed(mergeInto(funcOp, collidingFuncOp))) {
-        return failure();
+      {
+        InFlightDiagnostic diag = mergeInto(funcOp, collidingFuncOp);
+        if (failed(diag))
+          return diag;
       }
     }
   }
@@ -405,7 +409,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
            << "failed to verify target op after merging symbols";
 
   LLVM_DEBUG(DBGS() << "done merging ops\n");
-  return success();
+  return InFlightDiagnostic();
 }
 
 LogicalResult transform::applyTransformNamedSequence(

diff  --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
index 060dab334ed4388..f2d143bfb770ab9 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
@@ -8,7 +8,7 @@ module attributes {transform.with_named_sequence} {
   // expected-error @below {{external definition has a mismatching signature}}
   transform.named_sequence private @print_message(!transform.op<"builtin.module"> {transform.readonly})
 
-  // expected-error @below {{failed to merge library symbols into transform root}}
+  // expected-note @below {{failed to merge library symbols into transform root}}
   transform.sequence failures(propagate) {
   ^bb0(%arg0: !transform.op<"builtin.module">):
     include @print_message failures(propagate) (%arg0) : (!transform.op<"builtin.module">) -> ()
@@ -33,7 +33,7 @@ module attributes {transform.with_named_sequence} {
   // expected-error @below {{external definition has mismatching consumption annotations for argument #0}}
   transform.named_sequence private @consuming(%arg0: !transform.any_op {transform.readonly})
 
-  // expected-error @below {{failed to merge library symbols into transform root}}
+  // expected-note @below {{failed to merge library symbols into transform root}}
   transform.sequence failures(suppress) {
   ^bb0(%arg0: !transform.any_op):
     include @consuming failures(suppress) (%arg0) : (!transform.any_op) -> ()
@@ -49,7 +49,7 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 
-  // expected-error @below {{failed to merge library symbols into transform root}}
+  // expected-note @below {{failed to merge library symbols into transform root}}
   transform.sequence failures(suppress) {
   ^bb0(%arg0: !transform.any_op):
     include @print_message failures(propagate) (%arg0) : (!transform.any_op) -> ()


        


More information about the Mlir-commits mailing list