[Mlir-commits] [mlir] [mlir][transform] Handle multiple library preloading passes (PR #69320)

Ingo Müller llvmlistbot at llvm.org
Tue Oct 17 04:05:22 PDT 2023


https://github.com/ingomueller-net created https://github.com/llvm/llvm-project/pull/69320

The transform dialect stores a "library module" that the preload pass
can populate. Until now, each pass registered an additional module by
simply pushing it to a vector; however, the interpreter only used the
first of them. This commit turns the registration into "loading", i.e.,
each newly added module gets merged into the existing one. This allows
the loading to be split across several passes, and the interpreter now
takes all loaded modules into account when it uses the library. While
this design avoids repeated merging every time the library is accessed,
it requires that the implementation of merging modules lives in the
`TransformDialect` target (since the merging code and the dialect depend
on each other).

This resolves https://github.com/llvm/llvm-project/issues/69111.
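
For readers skimming the description, the difference between the old
"register" behaviour and the new "load" behaviour can be sketched with a
toy model. Everything below is hypothetical and independent of MLIR:
`ToyModule`, `ListLibrary`, and `MergedLibrary` are illustrative names
only, and the collision handling is deliberately simplified (the toy
rejects any name clash, whereas `mergeSymbolsInto` in the patch renames
private symbols and merges external declarations where possible).

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for a library module: symbol name -> body.
// Hypothetical type for illustration; this is not the MLIR API.
using ToyModule = std::map<std::string, std::string>;

// Old behaviour as described above: every module is appended to a list,
// but lookups only ever consult the first module.
struct ListLibrary {
  std::vector<ToyModule> modules;

  void registerModule(ToyModule module) {
    modules.push_back(std::move(module));
  }

  bool contains(const std::string &name) const {
    return !modules.empty() && modules.front().count(name) != 0;
  }
};

// New behaviour as described above: every newly loaded module is merged
// into the single existing module at load time, so lookups see all of them
// and no merging is needed on access.
struct MergedLibrary {
  ToyModule merged;

  // Simplification: any name collision fails the load and leaves the
  // library unchanged.
  bool load(const ToyModule &module) {
    for (const auto &entry : module)
      if (merged.count(entry.first) != 0)
        return false;
    merged.insert(module.begin(), module.end());
    return true;
  }

  bool contains(const std::string &name) const {
    return merged.count(name) != 0;
  }
};
```

With two modules defining `foo` and `bar` respectively, `ListLibrary`
only finds `foo`, while `MergedLibrary` finds both after two `load`
calls.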

From f5fb81de035f5f9596e838133691b9280bd147a5 Mon Sep 17 00:00:00 2001
From: Ingo Müller <ingomueller at google.com>
Date: Mon, 16 Oct 2023 12:38:47 +0000
Subject: [PATCH 1/2] Move transform interpreter utils to dialect.

This is a preparatory commit for fixing the handling of multiple
transform library modules. The plan is to merge every newly loaded
module into the global library module rather than keeping around a list.
Since that merging happens in the dialect and depends on the dialect,
the two depend on each other and thus have to live in the same CMake
target.
---
 .../Dialect/Transform/IR/TransformDialect.h   |  56 +++
 .../Transforms/TransformInterpreterUtils.h    |  49 ---
 .../Dialect/Transform/IR/TransformDialect.cpp | 366 ++++++++++++++++++
 .../Transforms/PreloadLibraryPass.cpp         |   1 -
 .../TransformInterpreterPassBase.cpp          |   1 -
 .../Transforms/TransformInterpreterUtils.cpp  | 349 -----------------
 6 files changed, 422 insertions(+), 400 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
index db27f2c6fc49b75..b696438eb3532c4 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
@@ -315,6 +315,62 @@ class BuildOnly : public DerivedTy {
   BuildOnly() : DerivedTy(/*buildOnly=*/true) {}
 };
 
+namespace detail {
+
+/// Expands the given list of `paths` to a list of `.mlir` files.
+///
+/// Each entry in `paths` may either be a regular file, in which case it ends up
+/// in the result list, or a directory, in which case all (regular) `.mlir`
+/// files in that directory are added. Any other file types lead to a failure.
+LogicalResult expandPathsToMLIRFiles(ArrayRef<std::string> &paths,
+                                     MLIRContext *context,
+                                     SmallVectorImpl<std::string> &fileNames);
+
+/// Utility to parse and verify the content of a `transformFileName` MLIR file
+/// containing a transform dialect specification.
+LogicalResult
+parseTransformModuleFromFile(MLIRContext *context,
+                             llvm::StringRef transformFileName,
+                             OwningOpRef<ModuleOp> &transformModule);
+
+/// Utility to parse, verify, aggregate and link the content of all mlir files
+/// nested under `transformLibraryPaths` and containing transform dialect
+/// specifications.
+LogicalResult
+assembleTransformLibraryFromPaths(MLIRContext *context,
+                                  ArrayRef<std::string> transformLibraryPaths,
+                                  OwningOpRef<ModuleOp> &transformModule);
+
+/// Utility to load a transform interpreter `module` from a module that has
+/// already been preloaded in the context.
+/// This mode is useful in cases where explicit parsing of a transform library
+/// from file is expected to be prohibitively expensive.
+/// In such cases, the transform module is expected to be found in the preloaded
+/// library modules of the transform dialect.
+/// Returns null if the module is not found.
+ModuleOp getPreloadedTransformModule(MLIRContext *context);
+
+/// Merge all symbols from `other` into `target`. Both ops need to implement the
+/// `SymbolTable` trait. Operations are moved from `other`, i.e., `other` may be
+/// modified by this function and might not verify after the function returns.
+/// Upon merging, private symbols may be renamed in order to avoid collisions in
+/// the result. Public symbols may not collide, with the exception of
+/// instances of `SymbolOpInterface`, where collisions are allowed if at least
+/// one of the two is external, in which case the other op preserved (or any one
+/// of the two if both are external).
+// TODO: Reconsider cloning individual ops rather than forcing users of the
+//       function to clone (or move) `other` in order to improve efficiency.
+//       This might primarily make sense if we can also prune the symbols that
+//       are merged to a subset (such as those that are actually used).
+LogicalResult mergeSymbolsInto(Operation *target,
+                               OwningOpRef<Operation *> other);
+
+/// Merge all symbols from `others` into `target`. See overload of
+/// `mergeSymbolsInto` on one `other` op for details.
+LogicalResult
+mergeSymbolsInto(Operation *target,
+                 MutableArrayRef<OwningOpRef<Operation *>> others);
+} // namespace detail
 } // namespace transform
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
index 3fc02267f26e9da..6ca6b8aa5e79e17 100644
--- a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
@@ -26,40 +26,6 @@ class Region;
 
 namespace transform {
 namespace detail {
-
-/// Expands the given list of `paths` to a list of `.mlir` files.
-///
-/// Each entry in `paths` may either be a regular file, in which case it ends up
-/// in the result list, or a directory, in which case all (regular) `.mlir`
-/// files in that directory are added. Any other file types lead to a failure.
-LogicalResult expandPathsToMLIRFiles(ArrayRef<std::string> &paths,
-                                     MLIRContext *context,
-                                     SmallVectorImpl<std::string> &fileNames);
-
-/// Utility to parse and verify the content of a `transformFileName` MLIR file
-/// containing a transform dialect specification.
-LogicalResult
-parseTransformModuleFromFile(MLIRContext *context,
-                             llvm::StringRef transformFileName,
-                             OwningOpRef<ModuleOp> &transformModule);
-
-/// Utility to parse, verify, aggregate and link the content of all mlir files
-/// nested under `transformLibraryPaths` and containing transform dialect
-/// specifications.
-LogicalResult
-assembleTransformLibraryFromPaths(MLIRContext *context,
-                                  ArrayRef<std::string> transformLibraryPaths,
-                                  OwningOpRef<ModuleOp> &transformModule);
-
-/// Utility to load a transform interpreter `module` from a module that has
-/// already been preloaded in the context.
-/// This mode is useful in cases where explicit parsing of a transform library
-/// from file is expected to be prohibitively expensive.
-/// In such cases, the transform module is expected to be found in the preloaded
-/// library modules of the transform dialect.
-/// Returns null if the module is not found.
-ModuleOp getPreloadedTransformModule(MLIRContext *context);
-
 /// Finds the first TransformOpInterface named `kTransformEntryPointSymbolName`
 /// that is either:
 ///   1. nested under `root` (takes precedence).
@@ -68,21 +34,6 @@ ModuleOp getPreloadedTransformModule(MLIRContext *context);
 TransformOpInterface findTransformEntryPoint(
     Operation *root, ModuleOp module,
     StringRef entryPoint = TransformDialect::kTransformEntryPointSymbolName);
-
-/// Merge all symbols from `other` into `target`. Both ops need to implement the
-/// `SymbolTable` trait. Operations are moved from `other`, i.e., `other` may be
-/// modified by this function and might not verify after the function returns.
-/// Upon merging, private symbols may be renamed in order to avoid collisions in
-/// the result. Public symbols may not collide, with the exception of
-/// instances of `SymbolOpInterface`, where collisions are allowed if at least
-/// one of the two is external, in which case the other op preserved (or any one
-/// of the two if both are external).
-// TODO: Reconsider cloning individual ops rather than forcing users of the
-//       function to clone (or move) `other` in order to improve efficiency.
-//       This might primarily make sense if we can also prune the symbols that
-//       are merged to a subset (such as those that are actually used).
-LogicalResult mergeSymbolsInto(Operation *target,
-                               OwningOpRef<Operation *> other);
 } // namespace detail
 
 /// Standalone util to apply the named sequence `entryPoint` to the payload.
diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
index 32c56e903268f74..1f2f548f8e471ce 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
@@ -12,7 +12,15 @@
 #include "mlir/Dialect/Transform/IR/TransformOps.h"
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Support/FileUtilities.h"
 #include "llvm/ADT/SCCIterator.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/SourceMgr.h"
+
+#define DEBUG_TYPE "transform-dialect"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
 
 using namespace mlir;
 
@@ -172,3 +180,361 @@ LogicalResult transform::TransformDialect::verifyOperationAttribute(
   return emitError(op->getLoc())
          << "unknown attribute: " << attribute.getName();
 }
+
+/// Expands the given list of `paths` to a list of `.mlir` files.
+///
+/// Each entry in `paths` may either be a regular file, in which case it ends up
+/// in the result list, or a directory, in which case all (regular) `.mlir`
+/// files in that directory are added. Any other file types lead to a failure.
+LogicalResult transform::detail::expandPathsToMLIRFiles(
+    ArrayRef<std::string> &paths, MLIRContext *context,
+    SmallVectorImpl<std::string> &fileNames) {
+  for (const std::string &path : paths) {
+    auto loc = FileLineColLoc::get(context, path, 0, 0);
+
+    if (llvm::sys::fs::is_regular_file(path)) {
+      LLVM_DEBUG(DBGS() << "Adding '" << path << "' to list of files\n");
+      fileNames.push_back(path);
+      continue;
+    }
+
+    if (!llvm::sys::fs::is_directory(path)) {
+      return emitError(loc)
+             << "'" << path << "' is neither a file nor a directory";
+    }
+
+    LLVM_DEBUG(DBGS() << "Looking for files in '" << path << "':\n");
+
+    std::error_code ec;
+    for (llvm::sys::fs::directory_iterator it(path, ec), itEnd;
+         it != itEnd && !ec; it.increment(ec)) {
+      const std::string &fileName = it->path();
+
+      if (it->type() != llvm::sys::fs::file_type::regular_file) {
+        LLVM_DEBUG(DBGS() << "  Skipping non-regular file '" << fileName
+                          << "'\n");
+        continue;
+      }
+
+      if (!StringRef(fileName).endswith(".mlir")) {
+        LLVM_DEBUG(DBGS() << "  Skipping '" << fileName
+                          << "' because it does not end with '.mlir'\n");
+        continue;
+      }
+
+      LLVM_DEBUG(DBGS() << "  Adding '" << fileName << "' to list of files\n");
+      fileNames.push_back(fileName);
+    }
+
+    if (ec)
+      return emitError(loc) << "error while opening files in '" << path
+                            << "': " << ec.message();
+  }
+
+  return success();
+}
+
+LogicalResult transform::detail::parseTransformModuleFromFile(
+    MLIRContext *context, llvm::StringRef transformFileName,
+    OwningOpRef<ModuleOp> &transformModule) {
+  if (transformFileName.empty()) {
+    LLVM_DEBUG(
+        DBGS() << "no transform file name specified, assuming the transform "
+                  "module is embedded in the IR next to the top-level\n");
+    return success();
+  }
+  // Parse transformFileName content into a ModuleOp.
+  std::string errorMessage;
+  auto memoryBuffer = mlir::openInputFile(transformFileName, &errorMessage);
+  if (!memoryBuffer) {
+    return emitError(FileLineColLoc::get(
+               StringAttr::get(context, transformFileName), 0, 0))
+           << "failed to open transform file: " << errorMessage;
+  }
+  // Tell sourceMgr about this buffer, the parser will pick it up.
+  llvm::SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
+  transformModule =
+      OwningOpRef<ModuleOp>(parseSourceFile<ModuleOp>(sourceMgr, context));
+  return mlir::verify(*transformModule);
+}
+
+ModuleOp transform::detail::getPreloadedTransformModule(MLIRContext *context) {
+  auto preloadedLibraryRange =
+      context->getOrLoadDialect<transform::TransformDialect>()
+          ->getLibraryModules();
+  if (!preloadedLibraryRange.empty())
+    return *preloadedLibraryRange.begin();
+  return ModuleOp();
+}
+
+LogicalResult transform::detail::assembleTransformLibraryFromPaths(
+    MLIRContext *context, ArrayRef<std::string> transformLibraryPaths,
+    OwningOpRef<ModuleOp> &transformModule) {
+  // Assemble list of library files.
+  SmallVector<std::string> libraryFileNames;
+  if (failed(detail::expandPathsToMLIRFiles(transformLibraryPaths, context,
+                                            libraryFileNames)))
+    return failure();
+
+  // Parse modules from library files.
+  SmallVector<OwningOpRef<ModuleOp>> parsedLibraries;
+  for (const std::string &libraryFileName : libraryFileNames) {
+    OwningOpRef<ModuleOp> parsedLibrary;
+    if (failed(transform::detail::parseTransformModuleFromFile(
+            context, libraryFileName, parsedLibrary)))
+      return failure();
+    parsedLibraries.push_back(std::move(parsedLibrary));
+  }
+
+  // Merge parsed libraries into one module.
+  auto loc = FileLineColLoc::get(context, "<shared-library-module>", 0, 0);
+  OwningOpRef<ModuleOp> mergedParsedLibraries =
+      ModuleOp::create(loc, "__transform");
+  {
+    mergedParsedLibraries.get()->setAttr("transform.with_named_sequence",
+                                         UnitAttr::get(context));
+    IRRewriter rewriter(context);
+    // TODO: extend `mergeSymbolsInto` to support multiple `other` modules.
+    for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) {
+      if (failed(transform::detail::mergeSymbolsInto(
+              mergedParsedLibraries.get(), std::move(parsedLibrary))))
+        return mergedParsedLibraries->emitError()
+               << "failed to verify merged transform module";
+    }
+  }
+
+  transformModule = std::move(mergedParsedLibraries);
+  return success();
+}
+
+/// Return whether `func1` can be merged into `func2`. For that to work
+/// `func1` has to be a declaration (aka has to be external) and `func2`
+/// either has to be a declaration as well, or it has to be public (otherwise,
+/// it wouldn't be visible by `func1`).
+static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
+  return func1.isExternal() && (func2.isPublic() || func2.isExternal());
+}
+
+/// Merge `func1` into `func2`. The two ops must be inside the same parent op
+/// and mergable according to `canMergeInto`. The function erases `func1` such
+/// that only `func2` exists when the function returns.
+static LogicalResult mergeInto(FunctionOpInterface func1,
+                               FunctionOpInterface func2) {
+  assert(canMergeInto(func1, func2));
+  assert(func1->getParentOp() == func2->getParentOp() &&
+         "expected func1 and func2 to be in the same parent op");
+
+  // Check that function signatures match.
+  if (func1.getFunctionType() != func2.getFunctionType()) {
+    return func1.emitError()
+           << "external definition has a mismatching signature ("
+           << func2.getFunctionType() << ")";
+  }
+
+  // Check and merge argument attributes.
+  MLIRContext *context = func1->getContext();
+  auto *td = context->getLoadedDialect<transform::TransformDialect>();
+  StringAttr consumedName = td->getConsumedAttrName();
+  StringAttr readOnlyName = td->getReadOnlyAttrName();
+  for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) {
+    bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr;
+    bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr;
+    bool isConsumed = func1.getArgAttr(i, consumedName) != nullptr;
+    bool isReadonly = func1.getArgAttr(i, readOnlyName) != nullptr;
+    if (!isExternalConsumed && !isExternalReadonly) {
+      if (isConsumed)
+        func2.setArgAttr(i, consumedName, UnitAttr::get(context));
+      else if (isReadonly)
+        func2.setArgAttr(i, readOnlyName, UnitAttr::get(context));
+      continue;
+    }
+
+    if ((isExternalConsumed && !isConsumed) ||
+        (isExternalReadonly && !isReadonly)) {
+      return func1.emitError()
+             << "external definition has mismatching consumption "
+                "annotations for argument #"
+             << i;
+    }
+  }
+
+  // `func1` is the external one, so we can remove it.
+  assert(func1.isExternal());
+  func1->erase();
+
+  return success();
+}
+
+LogicalResult
+transform::detail::mergeSymbolsInto(Operation *target,
+                                    OwningOpRef<Operation *> other) {
+  assert(target->hasTrait<OpTrait::SymbolTable>() &&
+         "requires target to implement the 'SymbolTable' trait");
+  assert(other->hasTrait<OpTrait::SymbolTable>() &&
+         "requires target to implement the 'SymbolTable' trait");
+
+  SymbolTable targetSymbolTable(target);
+  SymbolTable otherSymbolTable(*other);
+
+  // Step 1:
+  //
+  // Rename private symbols in both ops in order to resolve conflicts that can
+  // be resolved that way.
+  LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n");
+  // TODO: Do we *actually* need to test in both directions?
+  for (auto &&[symbolTable, otherSymbolTable] : llvm::zip(
+           SmallVector<SymbolTable *, 2>{&targetSymbolTable, &otherSymbolTable},
+           SmallVector<SymbolTable *, 2>{&otherSymbolTable,
+                                         &targetSymbolTable})) {
+    Operation *symbolTableOp = symbolTable->getOp();
+    for (Operation &op : symbolTableOp->getRegion(0).front()) {
+      auto symbolOp = dyn_cast<SymbolOpInterface>(op);
+      if (!symbolOp)
+        continue;
+      StringAttr name = symbolOp.getNameAttr();
+      LLVM_DEBUG(DBGS() << "  found @" << name.getValue() << "\n");
+
+      // Check if there is a colliding op in the other module.
+      auto collidingOp =
+          cast_or_null<SymbolOpInterface>(otherSymbolTable->lookup(name));
+      if (!collidingOp)
+        continue;
+
+      LLVM_DEBUG(DBGS() << "    collision found for @" << name.getValue());
+
+      // Collisions are fine if both opt are functions and can be merged.
+      if (auto funcOp = dyn_cast<FunctionOpInterface>(op),
+          collidingFuncOp =
+              dyn_cast<FunctionOpInterface>(collidingOp.getOperation());
+          funcOp && collidingFuncOp) {
+        if (canMergeInto(funcOp, collidingFuncOp) ||
+            canMergeInto(collidingFuncOp, funcOp)) {
+          LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and "
+                                     "will be merged\n");
+          continue;
+        }
+
+        // If they can't be merged, proceed like any other collision.
+        LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions");
+      }
+
+      // Collision can be resolved by renaming if one of the ops is private.
+      auto renameToUnique =
+          [&](SymbolOpInterface op, SymbolOpInterface otherOp,
+              SymbolTable &symbolTable,
+              SymbolTable &otherSymbolTable) -> LogicalResult {
+        LLVM_DEBUG(llvm::dbgs() << ", renaming\n");
+        FailureOr<StringAttr> maybeNewName =
+            symbolTable.renameToUnique(op, {&otherSymbolTable});
+        if (failed(maybeNewName)) {
+          InFlightDiagnostic diag = op->emitError("failed to rename symbol");
+          diag.attachNote(otherOp->getLoc())
+              << "attempted renaming due to collision with this op";
+          return diag;
+        }
+        LLVM_DEBUG(DBGS() << "      renamed to @" << maybeNewName->getValue()
+                          << "\n");
+        return success();
+      };
+
+      if (symbolOp.isPrivate()) {
+        if (failed(renameToUnique(symbolOp, collidingOp, *symbolTable,
+                                  *otherSymbolTable)))
+          return failure();
+        continue;
+      }
+      if (collidingOp.isPrivate()) {
+        if (failed(renameToUnique(collidingOp, symbolOp, *otherSymbolTable,
+                                  *symbolTable)))
+          return failure();
+        continue;
+      }
+      LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");
+      InFlightDiagnostic diag = symbolOp.emitError()
+                                << "doubly defined symbol @" << name.getValue();
+      diag.attachNote(collidingOp->getLoc()) << "previously defined here";
+      return diag;
+    }
+  }
+
+  // TODO: This duplicates pass infrastructure. We should split this pass into
+  //       several and let the pass infrastructure do the verification.
+  for (auto *op : SmallVector<Operation *>{target, *other}) {
+    if (failed(mlir::verify(op)))
+      return op->emitError() << "failed to verify input op after renaming";
+  }
+
+  // Step 2:
+  //
+  // Move all ops from `other` into target and merge public symbols.
+  LLVM_DEBUG(DBGS() << "moving all symbols into target\n");
+  {
+    SmallVector<SymbolOpInterface> opsToMove;
+    for (Operation &op : other->getRegion(0).front()) {
+      if (auto symbol = dyn_cast<SymbolOpInterface>(op))
+        opsToMove.push_back(symbol);
+    }
+
+    for (SymbolOpInterface op : opsToMove) {
+      // Remember potentially colliding op in the target module.
+      auto collidingOp = cast_or_null<SymbolOpInterface>(
+          targetSymbolTable.lookup(op.getNameAttr()));
+
+      // Move op even if we get a collision.
+      LLVM_DEBUG(DBGS() << "  moving @" << op.getName());
+      op->moveBefore(&target->getRegion(0).front(),
+                     target->getRegion(0).front().end());
+
+      // If there is no collision, we are done.
+      if (!collidingOp) {
+        LLVM_DEBUG(llvm::dbgs() << " without collision\n");
+        continue;
+      }
+
+      // The two colliding ops must both be functions because we have already
+      // emitted errors otherwise earlier.
+      auto funcOp = cast<FunctionOpInterface>(op.getOperation());
+      auto collidingFuncOp =
+          cast<FunctionOpInterface>(collidingOp.getOperation());
+
+      // Both ops are in the target module now and can be treated
+      // symmetrically, so w.l.o.g. we can reduce to merging `funcOp` into
+      // `collidingFuncOp`.
+      if (!canMergeInto(funcOp, collidingFuncOp)) {
+        std::swap(funcOp, collidingFuncOp);
+      }
+      assert(canMergeInto(funcOp, collidingFuncOp));
+
+      LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at "
+                              << collidingFuncOp.getLoc() << ":\n"
+                              << collidingFuncOp << "\n");
+
+      // Update symbol table. This works with or without the previous `swap`.
+      targetSymbolTable.remove(funcOp);
+      targetSymbolTable.insert(collidingFuncOp);
+      assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);
+
+      // Do the actual merging.
+      if (failed(mergeInto(funcOp, collidingFuncOp))) {
+        return failure();
+      }
+    }
+  }
+
+  if (failed(mlir::verify(target)))
+    return target->emitError()
+           << "failed to verify target op after merging symbols";
+
+  LLVM_DEBUG(DBGS() << "done merging ops\n");
+  return success();
+}
+
+LogicalResult transform::detail::mergeSymbolsInto(
+    Operation *target, MutableArrayRef<OwningOpRef<Operation *>> others) {
+  for (OwningOpRef<Operation *> &other : others) {
+    if (failed(transform::detail::mergeSymbolsInto(target, std::move(other))))
+      return target->emitError() << "failed to verify merged transform module";
+  }
+  return success();
+}
diff --git a/mlir/lib/Dialect/Transform/Transforms/PreloadLibraryPass.cpp b/mlir/lib/Dialect/Transform/Transforms/PreloadLibraryPass.cpp
index d2e7108c0288623..2075f968b2eb8b8 100644
--- a/mlir/lib/Dialect/Transform/Transforms/PreloadLibraryPass.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/PreloadLibraryPass.cpp
@@ -8,7 +8,6 @@
 
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/Transforms/Passes.h"
-#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
 
 using namespace mlir;
 
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index 538c81fe39fddb2..4d1d9d38dcde151 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -15,7 +15,6 @@
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformOps.h"
-#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/IR/Visitors.h"
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
index 41feffffaf97b3f..d9c8bb4d26a2365 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
@@ -32,93 +32,6 @@ using namespace mlir;
 #define DEBUG_TYPE "transform-dialect-interpreter-utils"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
 
-/// Expands the given list of `paths` to a list of `.mlir` files.
-///
-/// Each entry in `paths` may either be a regular file, in which case it ends up
-/// in the result list, or a directory, in which case all (regular) `.mlir`
-/// files in that directory are added. Any other file types lead to a failure.
-LogicalResult transform::detail::expandPathsToMLIRFiles(
-    ArrayRef<std::string> &paths, MLIRContext *context,
-    SmallVectorImpl<std::string> &fileNames) {
-  for (const std::string &path : paths) {
-    auto loc = FileLineColLoc::get(context, path, 0, 0);
-
-    if (llvm::sys::fs::is_regular_file(path)) {
-      LLVM_DEBUG(DBGS() << "Adding '" << path << "' to list of files\n");
-      fileNames.push_back(path);
-      continue;
-    }
-
-    if (!llvm::sys::fs::is_directory(path)) {
-      return emitError(loc)
-             << "'" << path << "' is neither a file nor a directory";
-    }
-
-    LLVM_DEBUG(DBGS() << "Looking for files in '" << path << "':\n");
-
-    std::error_code ec;
-    for (llvm::sys::fs::directory_iterator it(path, ec), itEnd;
-         it != itEnd && !ec; it.increment(ec)) {
-      const std::string &fileName = it->path();
-
-      if (it->type() != llvm::sys::fs::file_type::regular_file) {
-        LLVM_DEBUG(DBGS() << "  Skipping non-regular file '" << fileName
-                          << "'\n");
-        continue;
-      }
-
-      if (!StringRef(fileName).endswith(".mlir")) {
-        LLVM_DEBUG(DBGS() << "  Skipping '" << fileName
-                          << "' because it does not end with '.mlir'\n");
-        continue;
-      }
-
-      LLVM_DEBUG(DBGS() << "  Adding '" << fileName << "' to list of files\n");
-      fileNames.push_back(fileName);
-    }
-
-    if (ec)
-      return emitError(loc) << "error while opening files in '" << path
-                            << "': " << ec.message();
-  }
-
-  return success();
-}
-
-LogicalResult transform::detail::parseTransformModuleFromFile(
-    MLIRContext *context, llvm::StringRef transformFileName,
-    OwningOpRef<ModuleOp> &transformModule) {
-  if (transformFileName.empty()) {
-    LLVM_DEBUG(
-        DBGS() << "no transform file name specified, assuming the transform "
-                  "module is embedded in the IR next to the top-level\n");
-    return success();
-  }
-  // Parse transformFileName content into a ModuleOp.
-  std::string errorMessage;
-  auto memoryBuffer = mlir::openInputFile(transformFileName, &errorMessage);
-  if (!memoryBuffer) {
-    return emitError(FileLineColLoc::get(
-               StringAttr::get(context, transformFileName), 0, 0))
-           << "failed to open transform file: " << errorMessage;
-  }
-  // Tell sourceMgr about this buffer, the parser will pick it up.
-  llvm::SourceMgr sourceMgr;
-  sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
-  transformModule =
-      OwningOpRef<ModuleOp>(parseSourceFile<ModuleOp>(sourceMgr, context));
-  return mlir::verify(*transformModule);
-}
-
-ModuleOp transform::detail::getPreloadedTransformModule(MLIRContext *context) {
-  auto preloadedLibraryRange =
-      context->getOrLoadDialect<transform::TransformDialect>()
-          ->getLibraryModules();
-  if (!preloadedLibraryRange.empty())
-    return *preloadedLibraryRange.begin();
-  return ModuleOp();
-}
-
 transform::TransformOpInterface
 transform::detail::findTransformEntryPoint(Operation *root, ModuleOp module,
                                            StringRef entryPoint) {
@@ -145,268 +58,6 @@ transform::detail::findTransformEntryPoint(Operation *root, ModuleOp module,
   return nullptr;
 }
 
-LogicalResult transform::detail::assembleTransformLibraryFromPaths(
-    MLIRContext *context, ArrayRef<std::string> transformLibraryPaths,
-    OwningOpRef<ModuleOp> &transformModule) {
-  // Assemble list of library files.
-  SmallVector<std::string> libraryFileNames;
-  if (failed(detail::expandPathsToMLIRFiles(transformLibraryPaths, context,
-                                            libraryFileNames)))
-    return failure();
-
-  // Parse modules from library files.
-  SmallVector<OwningOpRef<ModuleOp>> parsedLibraries;
-  for (const std::string &libraryFileName : libraryFileNames) {
-    OwningOpRef<ModuleOp> parsedLibrary;
-    if (failed(transform::detail::parseTransformModuleFromFile(
-            context, libraryFileName, parsedLibrary)))
-      return failure();
-    parsedLibraries.push_back(std::move(parsedLibrary));
-  }
-
-  // Merge parsed libraries into one module.
-  auto loc = FileLineColLoc::get(context, "<shared-library-module>", 0, 0);
-  OwningOpRef<ModuleOp> mergedParsedLibraries =
-      ModuleOp::create(loc, "__transform");
-  {
-    mergedParsedLibraries.get()->setAttr("transform.with_named_sequence",
-                                         UnitAttr::get(context));
-    IRRewriter rewriter(context);
-    // TODO: extend `mergeSymbolsInto` to support multiple `other` modules.
-    for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) {
-      if (failed(transform::detail::mergeSymbolsInto(
-              mergedParsedLibraries.get(), std::move(parsedLibrary))))
-        return mergedParsedLibraries->emitError()
-               << "failed to verify merged transform module";
-    }
-  }
-
-  transformModule = std::move(mergedParsedLibraries);
-  return success();
-}
-
-/// Return whether `func1` can be merged into `func2`. For that to work
-/// `func1` has to be a declaration (aka has to be external) and `func2`
-/// either has to be a declaration as well, or it has to be public (otherwise,
-/// it wouldn't be visible by `func1`).
-static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
-  return func1.isExternal() && (func2.isPublic() || func2.isExternal());
-}
-
-/// Merge `func1` into `func2`. The two ops must be inside the same parent op
-/// and mergable according to `canMergeInto`. The function erases `func1` such
-/// that only `func2` exists when the function returns.
-static LogicalResult mergeInto(FunctionOpInterface func1,
-                               FunctionOpInterface func2) {
-  assert(canMergeInto(func1, func2));
-  assert(func1->getParentOp() == func2->getParentOp() &&
-         "expected func1 and func2 to be in the same parent op");
-
-  // Check that function signatures match.
-  if (func1.getFunctionType() != func2.getFunctionType()) {
-    return func1.emitError()
-           << "external definition has a mismatching signature ("
-           << func2.getFunctionType() << ")";
-  }
-
-  // Check and merge argument attributes.
-  MLIRContext *context = func1->getContext();
-  auto *td = context->getLoadedDialect<transform::TransformDialect>();
-  StringAttr consumedName = td->getConsumedAttrName();
-  StringAttr readOnlyName = td->getReadOnlyAttrName();
-  for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) {
-    bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr;
-    bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr;
-    bool isConsumed = func1.getArgAttr(i, consumedName) != nullptr;
-    bool isReadonly = func1.getArgAttr(i, readOnlyName) != nullptr;
-    if (!isExternalConsumed && !isExternalReadonly) {
-      if (isConsumed)
-        func2.setArgAttr(i, consumedName, UnitAttr::get(context));
-      else if (isReadonly)
-        func2.setArgAttr(i, readOnlyName, UnitAttr::get(context));
-      continue;
-    }
-
-    if ((isExternalConsumed && !isConsumed) ||
-        (isExternalReadonly && !isReadonly)) {
-      return func1.emitError()
-             << "external definition has mismatching consumption "
-                "annotations for argument #"
-             << i;
-    }
-  }
-
-  // `func1` is the external one, so we can remove it.
-  assert(func1.isExternal());
-  func1->erase();
-
-  return success();
-}
-
-LogicalResult
-transform::detail::mergeSymbolsInto(Operation *target,
-                                    OwningOpRef<Operation *> other) {
-  assert(target->hasTrait<OpTrait::SymbolTable>() &&
-         "requires target to implement the 'SymbolTable' trait");
-  assert(other->hasTrait<OpTrait::SymbolTable>() &&
-         "requires target to implement the 'SymbolTable' trait");
-
-  SymbolTable targetSymbolTable(target);
-  SymbolTable otherSymbolTable(*other);
-
-  // Step 1:
-  //
-  // Rename private symbols in both ops in order to resolve conflicts that can
-  // be resolved that way.
-  LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n");
-  // TODO: Do we *actually* need to test in both directions?
-  for (auto &&[symbolTable, otherSymbolTable] : llvm::zip(
-           SmallVector<SymbolTable *, 2>{&targetSymbolTable, &otherSymbolTable},
-           SmallVector<SymbolTable *, 2>{&otherSymbolTable,
-                                         &targetSymbolTable})) {
-    Operation *symbolTableOp = symbolTable->getOp();
-    for (Operation &op : symbolTableOp->getRegion(0).front()) {
-      auto symbolOp = dyn_cast<SymbolOpInterface>(op);
-      if (!symbolOp)
-        continue;
-      StringAttr name = symbolOp.getNameAttr();
-      LLVM_DEBUG(DBGS() << "  found @" << name.getValue() << "\n");
-
-      // Check if there is a colliding op in the other module.
-      auto collidingOp =
-          cast_or_null<SymbolOpInterface>(otherSymbolTable->lookup(name));
-      if (!collidingOp)
-        continue;
-
-      LLVM_DEBUG(DBGS() << "    collision found for @" << name.getValue());
-
-      // Collisions are fine if both opt are functions and can be merged.
-      if (auto funcOp = dyn_cast<FunctionOpInterface>(op),
-          collidingFuncOp =
-              dyn_cast<FunctionOpInterface>(collidingOp.getOperation());
-          funcOp && collidingFuncOp) {
-        if (canMergeInto(funcOp, collidingFuncOp) ||
-            canMergeInto(collidingFuncOp, funcOp)) {
-          LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and "
-                                     "will be merged\n");
-          continue;
-        }
-
-        // If they can't be merged, proceed like any other collision.
-        LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions");
-      }
-
-      // Collision can be resolved by renaming if one of the ops is private.
-      auto renameToUnique =
-          [&](SymbolOpInterface op, SymbolOpInterface otherOp,
-              SymbolTable &symbolTable,
-              SymbolTable &otherSymbolTable) -> LogicalResult {
-        LLVM_DEBUG(llvm::dbgs() << ", renaming\n");
-        FailureOr<StringAttr> maybeNewName =
-            symbolTable.renameToUnique(op, {&otherSymbolTable});
-        if (failed(maybeNewName)) {
-          InFlightDiagnostic diag = op->emitError("failed to rename symbol");
-          diag.attachNote(otherOp->getLoc())
-              << "attempted renaming due to collision with this op";
-          return diag;
-        }
-        LLVM_DEBUG(DBGS() << "      renamed to @" << maybeNewName->getValue()
-                          << "\n");
-        return success();
-      };
-
-      if (symbolOp.isPrivate()) {
-        if (failed(renameToUnique(symbolOp, collidingOp, *symbolTable,
-                                  *otherSymbolTable)))
-          return failure();
-        continue;
-      }
-      if (collidingOp.isPrivate()) {
-        if (failed(renameToUnique(collidingOp, symbolOp, *otherSymbolTable,
-                                  *symbolTable)))
-          return failure();
-        continue;
-      }
-      LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");
-      InFlightDiagnostic diag = symbolOp.emitError()
-                                << "doubly defined symbol @" << name.getValue();
-      diag.attachNote(collidingOp->getLoc()) << "previously defined here";
-      return diag;
-    }
-  }
-
-  // TODO: This duplicates pass infrastructure. We should split this pass into
-  //       several and let the pass infrastructure do the verification.
-  for (auto *op : SmallVector<Operation *>{target, *other}) {
-    if (failed(mlir::verify(op)))
-      return op->emitError() << "failed to verify input op after renaming";
-  }
-
-  // Step 2:
-  //
-  // Move all ops from `other` into target and merge public symbols.
-  LLVM_DEBUG(DBGS() << "moving all symbols into target\n");
-  {
-    SmallVector<SymbolOpInterface> opsToMove;
-    for (Operation &op : other->getRegion(0).front()) {
-      if (auto symbol = dyn_cast<SymbolOpInterface>(op))
-        opsToMove.push_back(symbol);
-    }
-
-    for (SymbolOpInterface op : opsToMove) {
-      // Remember potentially colliding op in the target module.
-      auto collidingOp = cast_or_null<SymbolOpInterface>(
-          targetSymbolTable.lookup(op.getNameAttr()));
-
-      // Move op even if we get a collision.
-      LLVM_DEBUG(DBGS() << "  moving @" << op.getName());
-      op->moveBefore(&target->getRegion(0).front(),
-                     target->getRegion(0).front().end());
-
-      // If there is no collision, we are done.
-      if (!collidingOp) {
-        LLVM_DEBUG(llvm::dbgs() << " without collision\n");
-        continue;
-      }
-
-      // The two colliding ops must both be functions because we have already
-      // emitted errors otherwise earlier.
-      auto funcOp = cast<FunctionOpInterface>(op.getOperation());
-      auto collidingFuncOp =
-          cast<FunctionOpInterface>(collidingOp.getOperation());
-
-      // Both ops are in the target module now and can be treated
-      // symmetrically, so w.l.o.g. we can reduce to merging `funcOp` into
-      // `collidingFuncOp`.
-      if (!canMergeInto(funcOp, collidingFuncOp)) {
-        std::swap(funcOp, collidingFuncOp);
-      }
-      assert(canMergeInto(funcOp, collidingFuncOp));
-
-      LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at "
-                              << collidingFuncOp.getLoc() << ":\n"
-                              << collidingFuncOp << "\n");
-
-      // Update symbol table. This works with or without the previous `swap`.
-      targetSymbolTable.remove(funcOp);
-      targetSymbolTable.insert(collidingFuncOp);
-      assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);
-
-      // Do the actual merging.
-      if (failed(mergeInto(funcOp, collidingFuncOp))) {
-        return failure();
-      }
-    }
-  }
-
-  if (failed(mlir::verify(target)))
-    return target->emitError()
-           << "failed to verify target op after merging symbols";
-
-  LLVM_DEBUG(DBGS() << "done merging ops\n");
-  return success();
-}
-
 LogicalResult transform::applyTransformNamedSequence(
     Operation *payload, ModuleOp transformModule,
     const TransformOptions &options, StringRef entryPoint) {

>From 3c0d1410c11746291aebeb597041045f7aa4f056 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Mon, 16 Oct 2023 12:47:50 +0000
Subject: [PATCH 2/2] [mlir][transform] Handle multiple library preloading
 passes.

The transform dialect stores a "library module" that the preload pass
can populate. Until now, each pass registered an additional module by
simply pushing it to a vector; however, the interpreter only used the
first of them. This commit turns the registration into "loading", i.e.,
each newly added module gets merged into the existing one. This allows
the loading to be split into several passes, and using the library in
the interpreter now takes all of them into account. While this design
avoids repeated merging every time the library is accessed, it requires
that the implementation of merging modules lives in the
`TransformDialect` target (since it and the dialect depend on each
other).

This resolves #69111.
---
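Note: the load-and-merge behaviour described above can be illustrated with a
small standalone model. This is a sketch under stated assumptions, not the
MLIR implementation: `ToyModule`, `ToyLibrary` and `loadInto` are hypothetical
names, a "module" is reduced to a map from symbol name to body, and an empty
body stands in for an external declaration. The real merging is done by
`detail::mergeSymbolsInto`, which additionally renames private symbols and
checks signatures and argument attributes.

```cpp
#include <map>
#include <string>
#include <utility>

// Hypothetical stand-in for a module: symbol name -> body.
// An empty body models an external declaration.
using ToyModule = std::map<std::string, std::string>;

// Hypothetical stand-in for the dialect-owned library module.
class ToyLibrary {
public:
  // Merges `incoming` into the single library module. Returns false if a
  // symbol is defined (non-empty body) on both sides. Symbols merged before
  // the conflict was detected stay in the library in this toy model.
  bool loadInto(ToyModule incoming) {
    for (auto &[name, body] : incoming) {
      auto it = library.find(name);
      if (it == library.end()) {
        // New symbol: take it as is.
        library.emplace(name, std::move(body));
        continue;
      }
      if (body.empty())
        continue; // Incoming declaration: keep what the library has.
      if (it->second.empty()) {
        // Incoming definition resolves an existing declaration.
        it->second = std::move(body);
        continue;
      }
      return false; // Two definitions of the same symbol.
    }
    return true;
  }

  // Returns the merged library, analogous to a single-module getter.
  const ToyModule &get() const { return library; }

private:
  ToyModule library;
};
```

In this model, loading a module that only declares `private_helper` and a
module that defines it yields the same library in either order, which is the
property the two new RUN lines in preload-library.mlir exercise with the real
passes.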
 .../Dialect/Transform/IR/TransformDialect.td  | 34 +++++++++----------
 .../Dialect/Transform/IR/TransformDialect.cpp | 23 +++++++++----
 .../Transforms/PreloadLibraryPass.cpp         |  4 ++-
 .../Dialect/Transform/preload-library.mlir    | 12 +++++++
 mlir/unittests/Dialect/Transform/Preload.cpp  |  2 +-
 5 files changed, 49 insertions(+), 26 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
index ad6804673b770ca..211663258bfb133 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
@@ -78,20 +78,19 @@ def Transform_Dialect : Dialect {
     using ExtensionTypePrintingHook =
         std::function<void(::mlir::Type, ::mlir::AsmPrinter &)>;
 
-    /// Appends the given module as a transform symbol library available to
-    /// all dialect users.
-    void registerLibraryModule(::mlir::OwningOpRef<::mlir::ModuleOp> &&
-                                library) {
-      libraryModules.push_back(std::move(library));
-    }
-
-    /// Returns a range of registered library modules.
-    auto getLibraryModules() const {
-      return ::llvm::map_range(
-          libraryModules,
-          [](const ::mlir::OwningOpRef<::mlir::ModuleOp> &library) {
-        return library.get();
-      });
+    /// Initializes the (initially empty) transform symbol library module.
+    void initializeLibraryModule();
+
+    /// Loads the given module into the transform symbol library module.
+    LogicalResult loadIntoLibraryModule(::mlir::OwningOpRef<::mlir::ModuleOp> &&
+                                        library);
+
+    /// Returns the transform symbol library module available to all dialect
+    /// users.
+    ModuleOp getLibraryModule() const {
+      if (libraryModule)
+        return libraryModule.get();
+      return ModuleOp();
     }
 
   private:
@@ -153,10 +152,9 @@ def Transform_Dialect : Dialect {
     ::llvm::DenseMap<::mlir::TypeID, ExtensionTypePrintingHook>
         typePrintingHooks;
 
-    /// Modules containing symbols, e.g. named sequences, that will be
-    /// resolved by the interpreter when used.
-    ::llvm::SmallVector<::mlir::OwningOpRef<::mlir::ModuleOp>, 2>
-        libraryModules;
+    /// Module containing symbols, e.g. named sequences, that will be resolved
+    /// by the interpreter when used.
+    ::mlir::OwningOpRef<::mlir::ModuleOp> libraryModule;
   }];
 }
 
diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
index 1f2f548f8e471ce..e12176be57b0eb7 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
@@ -73,6 +73,7 @@ void transform::TransformDialect::initialize() {
 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
       >();
   initializeTypes();
+  initializeLibraryModule();
 }
 
 Type transform::TransformDialect::parseType(DialectAsmParser &parser) const {
@@ -97,6 +98,20 @@ void transform::TransformDialect::printType(Type type,
   it->getSecond()(type, printer);
 }
 
+void transform::TransformDialect::initializeLibraryModule() {
+  MLIRContext *context = getContext();
+  auto loc =
+      FileLineColLoc::get(context, "<transform-dialect-library-module>", 0, 0);
+  libraryModule = ModuleOp::create(loc, "__transform_library");
+  libraryModule.get()->setAttr(TransformDialect::kWithNamedSequenceAttrName,
+                               UnitAttr::get(context));
+}
+
+LogicalResult transform::TransformDialect::loadIntoLibraryModule(
+    ::mlir::OwningOpRef<::mlir::ModuleOp> &&library) {
+  return detail::mergeSymbolsInto(getLibraryModule(), std::move(library));
+}
+
 void transform::TransformDialect::reportDuplicateTypeRegistration(
     StringRef mnemonic) {
   std::string buffer;
@@ -260,12 +275,8 @@ LogicalResult transform::detail::parseTransformModuleFromFile(
 }
 
 ModuleOp transform::detail::getPreloadedTransformModule(MLIRContext *context) {
-  auto preloadedLibraryRange =
-      context->getOrLoadDialect<transform::TransformDialect>()
-          ->getLibraryModules();
-  if (!preloadedLibraryRange.empty())
-    return *preloadedLibraryRange.begin();
-  return ModuleOp();
+  return context->getOrLoadDialect<transform::TransformDialect>()
+      ->getLibraryModule();
 }
 
 LogicalResult transform::detail::assembleTransformLibraryFromPaths(
diff --git a/mlir/lib/Dialect/Transform/Transforms/PreloadLibraryPass.cpp b/mlir/lib/Dialect/Transform/Transforms/PreloadLibraryPass.cpp
index 2075f968b2eb8b8..2d89b98350b7c05 100644
--- a/mlir/lib/Dialect/Transform/Transforms/PreloadLibraryPass.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/PreloadLibraryPass.cpp
@@ -32,7 +32,9 @@ class PreloadLibraryPass
     // TODO: investigate using a resource blob if some ownership mode allows it.
     auto *dialect =
         getContext().getOrLoadDialect<transform::TransformDialect>();
-    dialect->registerLibraryModule(std::move(mergedParsedLibraries));
+    if (failed(
+            dialect->loadIntoLibraryModule(std::move(mergedParsedLibraries))))
+      signalPassFailure();
   }
 };
 } // namespace
diff --git a/mlir/test/Dialect/Transform/preload-library.mlir b/mlir/test/Dialect/Transform/preload-library.mlir
index 61d22252dc61dfd..9d51fdbfd415406 100644
--- a/mlir/test/Dialect/Transform/preload-library.mlir
+++ b/mlir/test/Dialect/Transform/preload-library.mlir
@@ -3,6 +3,18 @@
 // RUN:   -transform-interpreter=entry-point=private_helper \
 // RUN:   -split-input-file -verify-diagnostics
 
+// RUN: mlir-opt %s \
+// RUN:   -transform-preload-library=transform-library-paths=%p%{fs-sep}test-interpreter-library/definitions-self-contained.mlir \
+// RUN:   -transform-preload-library=transform-library-paths=%p%{fs-sep}test-interpreter-library/definitions-with-unresolved.mlir \
+// RUN:   -transform-interpreter=entry-point=private_helper \
+// RUN:   -split-input-file -verify-diagnostics
+
+// RUN: mlir-opt %s \
+// RUN:   -transform-preload-library=transform-library-paths=%p%{fs-sep}test-interpreter-library/definitions-with-unresolved.mlir \
+// RUN:   -transform-preload-library=transform-library-paths=%p%{fs-sep}test-interpreter-library/definitions-self-contained.mlir \
+// RUN:   -transform-interpreter=entry-point=private_helper \
+// RUN:   -split-input-file -verify-diagnostics
+
 // expected-remark @below {{message}}
 module {}
 
diff --git a/mlir/unittests/Dialect/Transform/Preload.cpp b/mlir/unittests/Dialect/Transform/Preload.cpp
index d3c3044e0e0f776..2c9c15ff94433e2 100644
--- a/mlir/unittests/Dialect/Transform/Preload.cpp
+++ b/mlir/unittests/Dialect/Transform/Preload.cpp
@@ -67,7 +67,7 @@ TEST(Preload, ContextPreloadConstructedLibrary) {
   OwningOpRef<ModuleOp> transformLibrary =
       parseSourceString<ModuleOp>(library, parserConfig, "<transform-library>");
   EXPECT_TRUE(transformLibrary) << "failed to parse transform module";
-  dialect->registerLibraryModule(std::move(transformLibrary));
+  dialect->loadIntoLibraryModule(std::move(transformLibrary));
 
   ModuleOp retrievedTransformLibrary =
       transform::detail::getPreloadedTransformModule(&context);


