[Mlir-commits] [mlir] [mlir][transform] Allow passing various library files to interpreter. (PR #67120)

Ingo Müller llvmlistbot at llvm.org
Mon Sep 25 01:21:21 PDT 2023


https://github.com/ingomueller-net updated https://github.com/llvm/llvm-project/pull/67120

>From 2b7d6b4d4eaf5fe7159f93196e17423771d21fb7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Fri, 22 Sep 2023 15:10:30 +0000
Subject: [PATCH 1/3] [mlir][transform] Fix handling of transitive include in
 interpreter.

Until now, the interpreter would only load those symbols from the
provided library files that were declared in the main transform module.
However, sequences in the library may include other sequences on their
own. Until now, if such sequences were not *also* declared in the main
transform module, the interpreter would fail to resolve them. Forward
declaring all of them is undesirable as it defeats the purpose of
encapsulation into library modules.

This PR extends the loading mechanism as follows: in
`defineDeclaredSymbols`, not only are the definitions inserted that are
forward-declared in the main module, but any such inserted definition is
scanned for further dependencies, and those are processed in the same
way as the forward-declarations from the main module.
---
 .../TransformInterpreterPassBase.cpp          | 72 ++++++++++++++++---
 ...reter-external-symbol-decl-transitive.mlir | 27 +++++++
 .../test-interpreter-external-symbol-def.mlir |  5 ++
 3 files changed, 95 insertions(+), 9 deletions(-)
 create mode 100644 mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-transitive.mlir

diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index d5c65b23e3a2134..3c993e417b67efe 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -311,6 +311,9 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
   auto readOnlyName =
       StringAttr::get(&ctx, transform::TransformDialect::kArgReadOnlyAttrName);
 
+  // Collect symbols missing in the block.
+  SmallVector<SymbolOpInterface> missingSymbols;
+  LLVM_DEBUG(DBGS() << "searching block for missing symbols:\n");
   for (Operation &op : llvm::make_early_inc_range(block)) {
     LLVM_DEBUG(DBGS() << op << "\n");
     auto symbol = dyn_cast<SymbolOpInterface>(op);
@@ -318,25 +321,33 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
       continue;
     if (symbol->getNumRegions() == 1 && !symbol->getRegion(0).empty())
       continue;
+    LLVM_DEBUG(DBGS() << "  -> symbol missing\n");
+    missingSymbols.push_back(symbol);
+  }
 
-    LLVM_DEBUG(DBGS() << "looking for definition of symbol "
-                      << symbol.getNameAttr() << ":");
-    SymbolTable symbolTable(definitions);
-    Operation *externalSymbol = symbolTable.lookup(symbol.getNameAttr());
+  // Resolve missing symbols until they are all resolved.
+  while (!missingSymbols.empty()) {
+    SymbolOpInterface symbol = missingSymbols.pop_back_val();
+    LLVM_DEBUG(DBGS() << "looking for definition of symbol @"
+                      << symbol.getNameAttr().getValue() << ": ");
+    SymbolTable definitionsSymbolTable(definitions);
+    Operation *externalSymbol =
+        definitionsSymbolTable.lookup(symbol.getNameAttr());
     if (!externalSymbol || externalSymbol->getNumRegions() != 1 ||
         externalSymbol->getRegion(0).empty()) {
       LLVM_DEBUG(llvm::dbgs() << "not found\n");
       continue;
     }
 
-    auto symbolFunc = dyn_cast<FunctionOpInterface>(op);
+    auto symbolFunc = dyn_cast<FunctionOpInterface>(symbol.getOperation());
     auto externalSymbolFunc = dyn_cast<FunctionOpInterface>(externalSymbol);
     if (!symbolFunc || !externalSymbolFunc) {
       LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n");
       continue;
     }
 
-    LLVM_DEBUG(llvm::dbgs() << "found @" << externalSymbol << "\n");
+    LLVM_DEBUG(llvm::dbgs() << "found " << externalSymbol << " from "
+                            << externalSymbol->getLoc() << "\n");
     if (symbolFunc.getFunctionType() != externalSymbolFunc.getFunctionType()) {
       return symbolFunc.emitError()
              << "external definition has a mismatching signature ("
@@ -367,10 +378,53 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
       }
     }
 
-    OpBuilder builder(&op);
-    builder.setInsertionPoint(&op);
-    builder.clone(*externalSymbol);
+    OpBuilder builder(symbol);
+    builder.setInsertionPoint(symbol);
+    Operation *newSymbol = builder.clone(*externalSymbol);
+    builder.setInsertionPoint(newSymbol);
     symbol->erase();
+
+    LLVM_DEBUG(DBGS() << "scanning definition of @"
+                      << externalSymbolFunc.getNameAttr().getValue()
+                      << " for symbol usages\n");
+    externalSymbolFunc.walk([&](CallOpInterface callOp) {
+      LLVM_DEBUG(DBGS() << "  found symbol usage in:\n" << callOp << "\n");
+      CallInterfaceCallable callable = callOp.getCallableForCallee();
+      if (!isa<SymbolRefAttr>(callable)) {
+        LLVM_DEBUG(DBGS() << "    not a 'SymbolRefAttr'\n");
+        return WalkResult::advance();
+      }
+
+      StringRef callableSymbol =
+          cast<SymbolRefAttr>(callable).getLeafReference();
+      LLVM_DEBUG(DBGS() << "    looking for @" << callableSymbol
+                        << " in definitions: ");
+
+      Operation *callableOp = definitionsSymbolTable.lookup(callableSymbol);
+      if (!isa<SymbolRefAttr>(callable)) {
+        LLVM_DEBUG(llvm::dbgs() << "not found\n");
+        return WalkResult::advance();
+      }
+      LLVM_DEBUG(llvm::dbgs() << "found " << callableOp << " from "
+                              << callableOp->getLoc() << "\n");
+
+      if (!block.getParent() || !block.getParent()->getParentOp()) {
+        LLVM_DEBUG(DBGS() << "could not get parent of provided block");
+        return WalkResult::advance();
+      }
+
+      SymbolTable targetSymbolTable(block.getParent()->getParentOp());
+      if (targetSymbolTable.lookup(callableSymbol)) {
+        LLVM_DEBUG(DBGS() << "    symbol @" << callableSymbol
+                          << " already present in target\n");
+        return WalkResult::advance();
+      }
+
+      LLVM_DEBUG(DBGS() << "    cloning op into target\n");
+      builder.clone(*callableOp);
+
+      return WalkResult::advance();
+    });
   }
 
   return success();
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-transitive.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-transitive.mlir
new file mode 100644
index 000000000000000..0e9fa7c59bc4155
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-transitive.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
+// RUN:             --verify-diagnostics --split-input-file | FileCheck %s
+
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
+// RUN:             --verify-diagnostics --split-input-file | FileCheck %s
+
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
+// RUN:             --verify-diagnostics --split-input-file | FileCheck %s
+
+// The definition of the @bar named sequence is provided in another file. It
+// will be included because of the pass option. That sequence uses another named
+// sequence @foo, which should be made available here. Repeated application of
+// the same pass, with or without the library option, should not be a problem.
+// Note that the same diagnostic produced twice at the same location only
+// needs to be matched once.
+
+// expected-remark @below {{message}}
+module attributes {transform.with_named_sequence} {
+  // CHECK-DAG: transform.named_sequence @foo
+  // CHECK-DAG: transform.named_sequence @bar
+  transform.named_sequence private @bar(!transform.any_op {transform.readonly})
+
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    include @bar failures(propagate) (%arg0) : (!transform.any_op) -> ()
+  }
+}
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir
index 1149bda98ab8527..9aa2d46d5abb995 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir
@@ -1,6 +1,11 @@
 // RUN: mlir-opt %s
 
 module attributes {transform.with_named_sequence} {
+  transform.named_sequence @bar(%arg0: !transform.any_op) {
+    transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    transform.yield
+  }
+
   transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) {
     transform.test_print_remark_at_operand %arg0, "message" : !transform.any_op
     transform.yield

>From 48a81af250ae8a8789e04b681635090927d817b6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Mon, 25 Sep 2023 08:11:03 +0000
Subject: [PATCH 2/3] Minor improvements of variable names and debug output.

---
 .../Transforms/TransformInterpreterPassBase.cpp  | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index 3c993e417b67efe..475368f0f406ab9 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -388,19 +388,19 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
                       << externalSymbolFunc.getNameAttr().getValue()
                       << " for symbol usages\n");
     externalSymbolFunc.walk([&](CallOpInterface callOp) {
-      LLVM_DEBUG(DBGS() << "  found symbol usage in:\n" << callOp << "\n");
+      LLVM_DEBUG(DBGS() << "  call op in:\n" << callOp << "\n");
       CallInterfaceCallable callable = callOp.getCallableForCallee();
       if (!isa<SymbolRefAttr>(callable)) {
-        LLVM_DEBUG(DBGS() << "    not a 'SymbolRefAttr'\n");
+        LLVM_DEBUG(DBGS() << "    not a symbol usage\n");
         return WalkResult::advance();
       }
 
-      StringRef callableSymbol =
+      StringRef callableSymbolName =
           cast<SymbolRefAttr>(callable).getLeafReference();
-      LLVM_DEBUG(DBGS() << "    looking for @" << callableSymbol
+      LLVM_DEBUG(DBGS() << "    looking for @" << callableSymbolName
                         << " in definitions: ");
 
-      Operation *callableOp = definitionsSymbolTable.lookup(callableSymbol);
+      Operation *callableOp = definitionsSymbolTable.lookup(callableSymbolName);
       if (!isa<SymbolRefAttr>(callable)) {
         LLVM_DEBUG(llvm::dbgs() << "not found\n");
         return WalkResult::advance();
@@ -409,13 +409,13 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
                               << callableOp->getLoc() << "\n");
 
       if (!block.getParent() || !block.getParent()->getParentOp()) {
-        LLVM_DEBUG(DBGS() << "could not get parent of provided block");
+        LLVM_DEBUG(DBGS() << "could not get parent op of provided block");
         return WalkResult::advance();
       }
 
       SymbolTable targetSymbolTable(block.getParent()->getParentOp());
-      if (targetSymbolTable.lookup(callableSymbol)) {
-        LLVM_DEBUG(DBGS() << "    symbol @" << callableSymbol
+      if (targetSymbolTable.lookup(callableSymbolName)) {
+        LLVM_DEBUG(DBGS() << "    symbol @" << callableSymbolName
                           << " already present in target\n");
         return WalkResult::advance();
       }

>From e4193077e2aa8feab4a7cf098dd32c1d36307df2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Fri, 22 Sep 2023 12:05:38 +0000
Subject: [PATCH 3/3] [mlir][transform] Allow passing various library files to
 interpreter.

The transform interpreter accepts an argument to a "library" file with
named sequences. This patch extends this functionality such that (1)
several such individual files are accepted and (2) folders can be passed
in, in which all `*.mlir` files are loaded.
---
 .../Transforms/TransformInterpreterPassBase.h |  43 +++-
 .../TransformInterpreterPassBase.cpp          | 227 ++++++++++++++----
 mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir |   2 +-
 ...ter-external-symbol-decl-and-schedule.mlir |   4 +-
 ...-interpreter-external-symbol-decl-dir.mlir |  30 +++
 ...erpreter-external-symbol-decl-invalid.mlir |   4 +-
 ...reter-external-symbol-decl-transitive.mlir |   6 +-
 ...test-interpreter-external-symbol-decl.mlir |   6 +-
 .../definitions-self-contained.mlir}          |   0
 .../definitions-with-unresolved.mlir          |  10 +
 .../Dialect/Transform/match_batch_matmul.mlir |   2 +-
 .../Dialect/Transform/match_matmul.mlir       |   2 +-
 .../TestTransformDialectInterpreter.cpp       |  17 +-
 13 files changed, 277 insertions(+), 76 deletions(-)
 create mode 100644 mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-dir.mlir
 rename mlir/test/Dialect/Transform/{test-interpreter-external-symbol-def.mlir => test-interpreter-library/definitions-self-contained.mlir} (100%)
 create mode 100644 mlir/test/Dialect/Transform/test-interpreter-library/definitions-with-unresolved.mlir

diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
index 91903e254b0d5b3..5d021da90e9a594 100644
--- a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
@@ -33,7 +33,8 @@ namespace detail {
 /// Template-free implementation of TransformInterpreterPassBase::initialize.
 LogicalResult interpreterBaseInitializeImpl(
     MLIRContext *context, StringRef transformFileName,
-    StringRef transformLibraryFileName,
+    ArrayRef<std::string> transformLibraryFileNames,
+    ArrayRef<std::string> transformLibraryDirNames,
     std::shared_ptr<OwningOpRef<ModuleOp>> &module,
     std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
     function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
@@ -48,7 +49,8 @@ LogicalResult interpreterBaseRunOnOperationImpl(
     const RaggedArray<MappedValue> &extraMappings,
     const TransformOptions &options,
     const Pass::Option<std::string> &transformFileName,
-    const Pass::Option<std::string> &transformLibraryFileName,
+    const Pass::ListOption<std::string> &transformLibraryFileNames,
+    const Pass::ListOption<std::string> &transformLibraryDirNames,
     const Pass::Option<std::string> &debugPayloadRootTag,
     const Pass::Option<std::string> &debugTransformRootTag,
     StringRef binaryName);
@@ -62,9 +64,14 @@ LogicalResult interpreterBaseRunOnOperationImpl(
 ///     transform script. If empty, `debugTransformRootTag` is considered or the
 ///     pass root operation must contain a single top-level transform op that
 ///     will be interpreted.
-///   - transformLibraryFileName: if non-empty, the name of the file containing
+///   - transformLibraryFileNames: if non-empty, the names of files containing
 ///     definitions of external symbols referenced in the transform script.
-///     These definitions will be used to replace declarations.
+///     These definitions will be used to replace declarations and must be
+///     unique within all files provided by this and the next option.
+///   - transformLibraryDirNames: if non-empty, the name of directories
+///     containing definitions of external symbols referenced in the transform
+///     script. These definitions will be used to replace declarations and must
+///     be unique within all files provided by this and the previous option.
 ///   - debugPayloadRootTag: if non-empty, the value of the attribute named
 ///     `kTransformDialectTagAttrName` indicating the single op that is
 ///     considered the payload root of the transform interpreter; otherwise, the
@@ -115,17 +122,30 @@ class TransformInterpreterPassBase : public GeneratedBase<Concrete> {
     REQUIRE_PASS_OPTION(transformFileName);
     REQUIRE_PASS_OPTION(debugPayloadRootTag);
     REQUIRE_PASS_OPTION(debugTransformRootTag);
-    REQUIRE_PASS_OPTION(transformLibraryFileName);
 
 #undef REQUIRE_PASS_OPTION
 
+#define REQUIRE_PASS_LIST_OPTION(NAME)                                         \
+  static_assert(                                                               \
+      std::is_same_v<                                                          \
+          std::remove_reference_t<decltype(std::declval<Concrete &>().NAME)>,  \
+          Pass::ListOption<std::string>>,                                      \
+      "required " #NAME " string pass option is missing")
+
+    REQUIRE_PASS_LIST_OPTION(transformLibraryFileNames);
+    REQUIRE_PASS_LIST_OPTION(transformLibraryDirNames);
+
+#undef REQUIRE_PASS_LIST_OPTION
+
     StringRef transformFileName =
         static_cast<Concrete *>(this)->transformFileName;
-    StringRef transformLibraryFileName =
-        static_cast<Concrete *>(this)->transformLibraryFileName;
+    ArrayRef<std::string> transformLibraryFileNames =
+        static_cast<Concrete *>(this)->transformLibraryFileNames;
+    ArrayRef<std::string> transformLibraryDirNames =
+        static_cast<Concrete *>(this)->transformLibraryDirNames;
     return detail::interpreterBaseInitializeImpl(
-        context, transformFileName, transformLibraryFileName,
-        sharedTransformModule, transformLibraryModule,
+        context, transformFileName, transformLibraryFileNames,
+        transformLibraryDirNames, sharedTransformModule, transformLibraryModule,
         [this](OpBuilder &builder, Location loc) {
           return static_cast<Concrete *>(this)->constructTransformModule(
               builder, loc);
@@ -159,8 +179,9 @@ class TransformInterpreterPassBase : public GeneratedBase<Concrete> {
             op, pass->getArgument(), sharedTransformModule,
             transformLibraryModule,
             /*extraMappings=*/{}, options, pass->transformFileName,
-            pass->transformLibraryFileName, pass->debugPayloadRootTag,
-            pass->debugTransformRootTag, binaryName)) ||
+            pass->transformLibraryFileNames, pass->transformLibraryDirNames,
+            pass->debugPayloadRootTag, pass->debugTransformRootTag,
+            binaryName)) ||
         failed(pass->runAfterInterpreter(op))) {
       return pass->signalPassFailure();
     }
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index 475368f0f406ab9..1855f424f61a4ef 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/IR/Visitors.h"
@@ -161,15 +162,24 @@ static llvm::raw_ostream &
 printReproCall(llvm::raw_ostream &os, StringRef rootOpName, StringRef passName,
                const Pass::Option<std::string> &debugPayloadRootTag,
                const Pass::Option<std::string> &debugTransformRootTag,
-               const Pass::Option<std::string> &transformLibraryFileName,
+               const Pass::ListOption<std::string> &transformLibraryFileNames,
+               const Pass::ListOption<std::string> &transformLibraryDirNames,
                StringRef binaryName) {
-  std::string transformLibraryOption = "";
-  if (!transformLibraryFileName.empty()) {
-    transformLibraryOption =
-        llvm::formatv(" {0}={1}", transformLibraryFileName.getArgStr(),
-                      transformLibraryFileName.getValue())
-            .str();
+  std::string transformLibraryOptions = "";
+  {
+    llvm::raw_string_ostream optionStream(transformLibraryOptions);
+    if (!transformLibraryFileNames.empty()) {
+      optionStream << " " << transformLibraryFileNames.getArgStr() << "='";
+      llvm::interleave(transformLibraryFileNames, optionStream, ",");
+      optionStream << "'";
+    }
+    if (!transformLibraryDirNames.empty()) {
+      optionStream << " " << transformLibraryDirNames.getArgStr() << "='";
+      llvm::interleave(transformLibraryDirNames, optionStream, ",");
+      optionStream << "'";
+    }
   }
+
   os << llvm::formatv(
       "{7} --pass-pipeline=\"{0}({1}{{{2}={3} {4}={5}{6}})\"", rootOpName,
       passName, debugPayloadRootTag.getArgStr(),
@@ -180,7 +190,7 @@ printReproCall(llvm::raw_ostream &os, StringRef rootOpName, StringRef passName,
       debugTransformRootTag.empty()
           ? StringRef(kTransformDialectTagTransformContainerValue)
           : debugTransformRootTag,
-      transformLibraryOption, binaryName);
+      transformLibraryOptions, binaryName);
   return os;
 }
 
@@ -200,7 +210,8 @@ void saveReproToTempFile(
     llvm::raw_ostream &os, Operation *target, Operation *transform,
     StringRef passName, const Pass::Option<std::string> &debugPayloadRootTag,
     const Pass::Option<std::string> &debugTransformRootTag,
-    const Pass::Option<std::string> &transformLibraryFileName,
+    const Pass::ListOption<std::string> &transformLibraryFileNames,
+    const Pass::ListOption<std::string> &transformLibraryDirNames,
     StringRef binaryName) {
   using llvm::sys::fs::TempFile;
   Operation *root = getRootOperation(target);
@@ -227,7 +238,8 @@ void saveReproToTempFile(
   os << "=== Transform Interpreter Repro ===\n";
   printReproCall(os, root->getName().getStringRef(), passName,
                  debugPayloadRootTag, debugTransformRootTag,
-                 transformLibraryFileName, binaryName)
+                 transformLibraryFileNames, transformLibraryDirNames,
+                 binaryName)
       << " " << filename << "\n";
   os << "===================================\n";
 }
@@ -238,7 +250,8 @@ static void performOptionalDebugActions(
     Operation *target, Operation *transform, StringRef passName,
     const Pass::Option<std::string> &debugPayloadRootTag,
     const Pass::Option<std::string> &debugTransformRootTag,
-    const Pass::Option<std::string> &transformLibraryFileName,
+    const Pass::ListOption<std::string> &transformLibraryFileNames,
+    const Pass::ListOption<std::string> &transformLibraryDirNames,
     StringRef binaryName) {
   MLIRContext *context = target->getContext();
 
@@ -279,10 +292,10 @@ static void performOptionalDebugActions(
 
   DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_STDERR, {
     llvm::dbgs() << "=== Transform Interpreter Repro ===\n";
-    printReproCall(llvm::dbgs() << "cat <<EOF | ",
-                   root->getName().getStringRef(), passName,
-                   debugPayloadRootTag, debugTransformRootTag,
-                   transformLibraryFileName, binaryName)
+    printReproCall(
+        llvm::dbgs() << "cat <<EOF | ", root->getName().getStringRef(),
+        passName, debugPayloadRootTag, debugTransformRootTag,
+        transformLibraryFileNames, transformLibraryDirNames, binaryName)
         << "\n";
     printModuleForRepro(llvm::dbgs(), root, transform);
     llvm::dbgs() << "\nEOF\n";
@@ -292,7 +305,8 @@ static void performOptionalDebugActions(
   DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_FILE, {
     saveReproToTempFile(llvm::dbgs(), target, transform, passName,
                         debugPayloadRootTag, debugTransformRootTag,
-                        transformLibraryFileName, binaryName);
+                        transformLibraryFileNames, transformLibraryDirNames,
+                        binaryName);
   });
 
   // Remove temporary attributes if they were set.
@@ -437,7 +451,8 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
     const RaggedArray<MappedValue> &extraMappings,
     const TransformOptions &options,
     const Pass::Option<std::string> &transformFileName,
-    const Pass::Option<std::string> &transformLibraryFileName,
+    const Pass::ListOption<std::string> &transformLibraryFileNames,
+    const Pass::ListOption<std::string> &transformLibraryDirNames,
     const Pass::Option<std::string> &debugPayloadRootTag,
     const Pass::Option<std::string> &debugTransformRootTag,
     StringRef binaryName) {
@@ -494,7 +509,9 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
     }
     if (failed(defineDeclaredSymbols(*transformRoot->getBlock(),
                                      libraryModule->get())))
-      return failure();
+      return emitError(
+          transformRoot->getLoc(),
+          "failed to replace declarations with definitions in transform root");
   }
 
   // Step 4
@@ -503,7 +520,8 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
   // repro to stderr and/or a file.
   performOptionalDebugActions(target, transformRoot, passName,
                               debugPayloadRootTag, debugTransformRootTag,
-                              transformLibraryFileName, binaryName);
+                              transformLibraryFileNames,
+                              transformLibraryDirNames, binaryName);
 
   // Step 5
   // ------
@@ -514,51 +532,166 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
 
 LogicalResult transform::detail::interpreterBaseInitializeImpl(
     MLIRContext *context, StringRef transformFileName,
-    StringRef transformLibraryFileName,
-    std::shared_ptr<OwningOpRef<ModuleOp>> &module,
-    std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
+    ArrayRef<std::string> transformLibraryFileNames,
+    ArrayRef<std::string> transformLibraryDirNames,
+    std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
+    std::shared_ptr<OwningOpRef<ModuleOp>> &transformLibraryModule,
     function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
         moduleBuilder) {
-  OwningOpRef<ModuleOp> parsed;
-  if (failed(parseTransformModuleFromFile(context, transformFileName, parsed)))
-    return failure();
-  if (parsed && failed(mlir::verify(*parsed)))
-    return failure();
+  auto unknownLoc = UnknownLoc::get(context);
+
+  // Parse module from file.
+  OwningOpRef<ModuleOp> moduleFromFile;
+  {
+    auto loc = FileLineColLoc::get(context, transformFileName, 0, 0);
+    if (failed(parseTransformModuleFromFile(context, transformFileName,
+                                            moduleFromFile)))
+      return emitError(loc, "failed to parse transform module");
+    if (moduleFromFile && failed(mlir::verify(*moduleFromFile)))
+      return emitError(loc, "failed to verify transform module");
+  }
 
-  OwningOpRef<ModuleOp> parsedLibrary;
-  if (failed(parseTransformModuleFromFile(context, transformLibraryFileName,
-                                          parsedLibrary)))
-    return failure();
-  if (parsedLibrary && failed(mlir::verify(*parsedLibrary)))
-    return failure();
+  // Assemble list of library files.
+  SmallVector<std::string> libraryFileNames;
+  libraryFileNames.append(transformLibraryFileNames.begin(),
+                          transformLibraryFileNames.end());
+
+  for (const std::string &dirName : transformLibraryDirNames) {
+    LLVM_DEBUG(DBGS() << "Opening files in '" << dirName << "':\n");
+
+    std::error_code ec;
+    for (llvm::sys::fs::directory_iterator it(dirName, ec), itEnd;
+         it != itEnd && !ec; it.increment(ec)) {
+      const std::string &fileName = it->path();
+
+      if (it->type() != llvm::sys::fs::file_type::regular_file) {
+        LLVM_DEBUG(DBGS() << "  Skipping non-regular file '" << fileName
+                          << "'\n");
+        continue;
+      }
 
-  if (parsed) {
-    module = std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsed));
+      if (!StringRef(fileName).endswith(".mlir")) {
+        LLVM_DEBUG(DBGS() << "  Skipping '" << fileName
+                          << "' because it does not end with '.mlir'\n");
+        continue;
+      }
+
+      LLVM_DEBUG(DBGS() << "  Adding '" << fileName << "' to list of files\n");
+      libraryFileNames.push_back(fileName);
+    }
+
+    if (ec)
+      return emitError(unknownLoc, Twine("error while opening files in '") +
+                                       dirName + "': " + ec.message());
+  }
+
+  // Parse modules from library files.
+  SmallVector<OwningOpRef<ModuleOp>> parsedLibraries;
+  for (const std::string &libraryFileName : libraryFileNames) {
+    OwningOpRef<ModuleOp> parsedLibrary;
+    auto loc = FileLineColLoc::get(context, libraryFileName, 0, 0);
+    if (failed(parseTransformModuleFromFile(context, libraryFileName,
+                                            parsedLibrary)))
+      return emitError(loc, "failed to parse transform library module");
+    if (parsedLibrary && failed(mlir::verify(*parsedLibrary)))
+      return emitError(loc, "failed to verify transform library module");
+    parsedLibraries.push_back(std::move(parsedLibrary));
+  }
+
+  // Build shared transform module.
+  if (moduleFromFile) {
+    sharedTransformModule =
+        std::make_shared<OwningOpRef<ModuleOp>>(std::move(moduleFromFile));
   } else if (moduleBuilder) {
     // TODO: better location story.
-    auto location = UnknownLoc::get(context);
     auto localModule = std::make_shared<OwningOpRef<ModuleOp>>(
-        ModuleOp::create(location, "__transform"));
+        ModuleOp::create(unknownLoc, "__transform"));
 
     OpBuilder b(context);
     b.setInsertionPointToEnd(localModule->get().getBody());
-    if (std::optional<LogicalResult> result = moduleBuilder(b, location)) {
+    if (std::optional<LogicalResult> result = moduleBuilder(b, unknownLoc)) {
       if (failed(*result))
-        return failure();
-      module = std::move(localModule);
+        return emitError(unknownLoc,
+                         "failed to create shared transform module");
+      sharedTransformModule = std::move(localModule);
     }
   }
 
-  if (!parsedLibrary || !*parsedLibrary)
+  if (parsedLibraries.empty())
     return success();
 
-  if (module && *module) {
-    if (failed(defineDeclaredSymbols(*module->get().getBody(),
-                                     parsedLibrary.get())))
-      return failure();
+  // Merge parsed libraries into one module.
+  // TODO: better location story.
+  OwningOpRef<ModuleOp> mergedParsedLibraries =
+      ModuleOp::create(unknownLoc, "__transform");
+  {
+    mergedParsedLibraries.get()->setAttr("transform.with_named_sequence",
+                                         UnitAttr::get(context));
+    IRRewriter rewriter(context);
+    for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) {
+
+      // Handle symbol collisions.
+      SymbolTable symbolTable(mergedParsedLibraries.get());
+      for (Operation &op : parsedLibrary.get()) {
+        if (!isa<SymbolOpInterface>(op))
+          continue;
+        auto symbolOp = cast<SymbolOpInterface>(op);
+
+        Operation *clashingOp = symbolTable.lookup(symbolOp.getNameAttr());
+        if (!clashingOp)
+          continue;
+
+        // There is a clash -- handle it.
+
+        auto reportClashingSymbols = [&]() {
+          InFlightDiagnostic diag = emitError(
+              symbolOp->getLoc(),
+              "symbol already defined while loading transform library files");
+          diag.attachNote(clashingOp->getLoc())
+              << "previous occurrence of this symbol";
+          return diag;
+        };
+
+        // Clash with other op types: emit error.
+        if (!isa<NamedSequenceOp>(clashingOp) ||
+            !isa<NamedSequenceOp>(symbolOp))
+          return reportClashingSymbols();
+
+        // Clash with other named sequence op.
+        auto namedSequenceOp = cast<NamedSequenceOp>(symbolOp);
+        auto clashingNamedSequenceOp = cast<NamedSequenceOp>(clashingOp);
+        if (namedSequenceOp.getBody().empty()) {
+          // This named sequence op is a forward declaration: erase it.
+          rewriter.eraseOp(namedSequenceOp);
+        } else if (clashingNamedSequenceOp.getBody().empty()) {
+          // The clashing op is a forward declaration: erase it.
+          rewriter.eraseOp(clashingOp);
+        } else {
+          // Both ops are definitions: report the clash.
+          return reportClashingSymbols();
+        }
+      }
+
+      rewriter.inlineBlockBefore(parsedLibrary->getBody(),
+                                 mergedParsedLibraries->getBody(),
+                                 mergedParsedLibraries->getBody()->end());
+
+      if (failed(mlir::verify(*mergedParsedLibraries)))
+        return emitError(unknownLoc,
+                         "failed to verify merged transform module");
+    }
+  }
+
+  // Use parsed libaries to resolve symbols in shared transform module or return
+  // as separate library module.
+  if (sharedTransformModule && *sharedTransformModule) {
+    if (failed(defineDeclaredSymbols(*sharedTransformModule->get().getBody(),
+                                     mergedParsedLibraries.get())))
+      return emitError(unknownLoc, "failed to replace declarations with "
+                                   "definitions in shared transform module");
   } else {
-    libraryModule =
-        std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsedLibrary));
+    transformLibraryModule = std::make_shared<OwningOpRef<ModuleOp>>(
+        std::move(mergedParsedLibraries));
   }
   return success();
 }
diff --git a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir
index c219cfe08ac4d6a..c1253cdb92d5002 100644
--- a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir
@@ -2,7 +2,7 @@
 
 // RUN: mlir-opt %s -test-lower-to-llvm -cse | FileCheck %s
 
-// RUN: mlir-opt %s -test-transform-dialect-interpreter="transform-library-file-name=%p/lower-to-llvm-transform-symbol-def.mlir debug-payload-root-tag=payload" \
+// RUN: mlir-opt %s -test-transform-dialect-interpreter="transform-library-file-names=%p/lower-to-llvm-transform-symbol-def.mlir debug-payload-root-tag=payload" \
 // RUN:   -test-transform-dialect-erase-schedule -cse \
 // RUN: | FileCheck %s
 
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
index 3d4cb0776982934..27f7270c95e3b34 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
@@ -1,7 +1,7 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-names=%p/test-interpreter-library/definitions-self-contained.mlir})" \
 // RUN:             --verify-diagnostics
 
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-names=%p/test-interpreter-library/definitions-self-contained.mlir}, test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-names=%p/test-interpreter-library/definitions-self-contained.mlir})" \
 // RUN:             --verify-diagnostics
 
 // The external transform script has a declaration to the named sequence @foo,
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-dir.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-dir.mlir
new file mode 100644
index 000000000000000..12dd1d5265c256c
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-dir.mlir
@@ -0,0 +1,30 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-dir-names=%p%{fs-sep}test-interpreter-library})" \
+// RUN:             --verify-diagnostics --split-input-file | FileCheck %s
+
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-dir-names=%p%{fs-sep}test-interpreter-library}, test-transform-dialect-interpreter)" \
+// RUN:             --verify-diagnostics --split-input-file | FileCheck %s
+
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-dir-names=%p%{fs-sep}test-interpreter-library}, test-transform-dialect-interpreter{transform-library-dir-names=%p%{fs-sep}test-interpreter-library})" \
+// RUN:             --verify-diagnostics --split-input-file | FileCheck %s
+
+// The definition of the @foo named sequence is provided in another file. It
+// will be included because of the pass option. Repeated application of the
+// same pass, with or without the library option, should not be a problem.
+// Note that the same diagnostic produced twice at the same location only
+// needs to be matched once.
+
+// expected-remark @below {{message}}
+module attributes {transform.with_named_sequence} {
+  // CHECK-DAG: transform.named_sequence @baz
+  // CHECK-DAG: test_print_remark_at_operand %{{.*}}, "message"
+
+  // CHECK-DAG: transform.named_sequence @foo
+  // CHECK-DAG: transform.include @foo
+
+  transform.named_sequence @baz(%arg0: !transform.any_op {transform.readonly})
+
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    include @baz failures(propagate) (%arg0) : (!transform.any_op) -> ()
+  }
+}
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
index b21abbbdfd6d045..ea59019215e189b 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-names=%p/test-interpreter-library/definitions-self-contained.mlir}, test-transform-dialect-interpreter)" \
 // RUN:             --verify-diagnostics --split-input-file
 
 // The definition of the @foo named sequence is provided in another file. It
@@ -8,6 +8,7 @@ module attributes {transform.with_named_sequence} {
   // expected-error @below {{external definition has a mismatching signature}}
   transform.named_sequence private @foo(!transform.op<"builtin.module"> {transform.readonly})
 
+  // expected-error @below {{failed to replace declarations with definitions in transform root}}
   transform.sequence failures(propagate) {
   ^bb0(%arg0: !transform.op<"builtin.module">):
     include @foo failures(propagate) (%arg0) : (!transform.op<"builtin.module">) -> ()
@@ -32,6 +33,7 @@ module attributes {transform.with_named_sequence} {
   // expected-error @below {{external definition has mismatching consumption annotations for argument #0}}
   transform.named_sequence private @consuming(%arg0: !transform.any_op {transform.readonly})
 
+  // expected-error @below {{failed to replace declarations with definitions in transform root}}
   transform.sequence failures(suppress) {
   ^bb0(%arg0: !transform.any_op):
     include @consuming failures(suppress) (%arg0) : (!transform.any_op) -> ()
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-transitive.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-transitive.mlir
index 0e9fa7c59bc4155..21411328cc4d38b 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-transitive.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-transitive.mlir
@@ -1,10 +1,10 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-names=%p/test-interpreter-library/definitions-self-contained.mlir})" \
 // RUN:             --verify-diagnostics --split-input-file | FileCheck %s
 
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-names=%p/test-interpreter-library/definitions-self-contained.mlir}, test-transform-dialect-interpreter)" \
 // RUN:             --verify-diagnostics --split-input-file | FileCheck %s
 
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-names=%p/test-interpreter-library/definitions-self-contained.mlir}, test-transform-dialect-interpreter{transform-library-file-names=%p/test-interpreter-library/definitions-self-contained.mlir})" \
 // RUN:             --verify-diagnostics --split-input-file | FileCheck %s
 
 // The definition of the @bar named sequence is provided in another file. It
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
index 04b6c5a02e0adf1..f4627f441e8e5ca 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
@@ -1,10 +1,10 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-names=%p/test-interpreter-library/definitions-self-contained.mlir})" \
 // RUN:             --verify-diagnostics --split-input-file | FileCheck %s
 
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-names=%p/test-interpreter-library/definitions-self-contained.mlir}, test-transform-dialect-interpreter)" \
 // RUN:             --verify-diagnostics --split-input-file | FileCheck %s
 
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-names=%p/test-interpreter-library/definitions-self-contained.mlir}, test-transform-dialect-interpreter{transform-library-file-names=%p/test-interpreter-library/definitions-self-contained.mlir})" \
 // RUN:             --verify-diagnostics --split-input-file | FileCheck %s
 
 // The definition of the @foo named sequence is provided in another file. It
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir b/mlir/test/Dialect/Transform/test-interpreter-library/definitions-self-contained.mlir
similarity index 100%
rename from mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir
rename to mlir/test/Dialect/Transform/test-interpreter-library/definitions-self-contained.mlir
diff --git a/mlir/test/Dialect/Transform/test-interpreter-library/definitions-with-unresolved.mlir b/mlir/test/Dialect/Transform/test-interpreter-library/definitions-with-unresolved.mlir
new file mode 100644
index 000000000000000..d5ee1083e36ccad
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-library/definitions-with-unresolved.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-opt %s
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly})
+
+  transform.named_sequence @baz(%arg0: !transform.any_op) {
+    transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    transform.yield
+  }
+}
diff --git a/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir b/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir
index 73bc243ad76060d..c7c055c7567e6f3 100644
--- a/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir
+++ b/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-file-name=%p/match_matmul_common.mlir' --verify-diagnostics
+// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-file-names=%p/match_matmul_common.mlir' --verify-diagnostics
 
 module attributes { transform.with_named_sequence } {
   transform.named_sequence @_match_matmul_like(
diff --git a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
index f164a3d1bd99dd0..3e06f2ca0895841 100644
--- a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
+++ b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-file-name=%p/match_matmul_common.mlir' --verify-diagnostics
+// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-file-names=%p/match_matmul_common.mlir' --verify-diagnostics
 
 module attributes { transform.with_named_sequence } {
   transform.named_sequence @_match_matmul_like(
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
index f73deef9d5fd48c..28e5256385efaea 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
@@ -161,7 +161,8 @@ class TestTransformDialectInterpreterPass
     if (failed(transform::detail::interpreterBaseRunOnOperationImpl(
             getOperation(), getArgument(), getSharedTransformModule(),
             getTransformLibraryModule(), extraMapping, options,
-            transformFileName, transformLibraryFileName, debugPayloadRootTag,
+            transformFileName, transformLibraryFileNames,
+            transformLibraryDirNames, debugPayloadRootTag,
             debugTransformRootTag, getBinaryName())))
       return signalPassFailure();
   }
@@ -216,12 +217,16 @@ class TestTransformDialectInterpreterPass
           "the given value as container IR for top-level transform ops. This "
           "allows user control on what transformation to apply. If empty, "
           "select the container of the top-level transform op.")};
-  Option<std::string> transformLibraryFileName{
-      *this, "transform-library-file-name", llvm::cl::init(""),
+  ListOption<std::string> transformLibraryFileNames{
+      *this, "transform-library-file-names", llvm::cl::ZeroOrMore,
+      llvm::cl::desc("Optional filenames containing transform dialect symbol "
+                     "definitions to be injected into the transform module.")};
+  ListOption<std::string> transformLibraryDirNames{
+      *this, "transform-library-dir-names", llvm::cl::ZeroOrMore,
       llvm::cl::desc(
-          "Optional name of the file containing transform dialect symbol "
-          "definitions to be injected into the transform module.")};
-
+          "Optional directories containing transform dialect symbol "
+          "definitions to be injected into the transform module. All '.mlir' "
+          "files rooted under this directory will be loaded.")};
   Option<bool> testModuleGeneration{
       *this, "test-module-generation", llvm::cl::init(false),
       llvm::cl::desc("test the generation of the transform module during pass "



More information about the Mlir-commits mailing list