[Mlir-commits] [mlir] [mlir][transform] Allow passing various library files to interpreter. (PR #67120)

Ingo Müller llvmlistbot at llvm.org
Fri Oct 6 02:06:50 PDT 2023


https://github.com/ingomueller-net updated https://github.com/llvm/llvm-project/pull/67120

>From 1833cbb566b19b83d73255538feffdd4acde1775 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Fri, 22 Sep 2023 12:05:38 +0000
Subject: [PATCH 1/2] [mlir][transform] Allow passing various library files to
 interpreter.

The transform interpreter accepts an argument to a "library" file with
named sequences. This patch extends this functionality such that (1)
several such individual files are accepted and (2) folders can be passed
in, in which all `*.mlir` files are loaded.
---
 .../Transforms/TransformInterpreterPassBase.h |  31 ++--
 .../TransformInterpreterPassBase.cpp          | 144 ++++++++++++++----
 mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir |   2 +-
 ...ter-external-symbol-decl-and-schedule.mlir |   4 +-
 ...-interpreter-external-symbol-decl-dir.mlir |  28 ++++
 ...erpreter-external-symbol-decl-invalid.mlir |   5 +-
 ...test-interpreter-external-symbol-decl.mlir |   4 +-
 .../definitions-self-contained.mlir}          |   0
 .../definitions-with-unresolved.mlir          |  10 ++
 .../Dialect/Transform/match_batch_matmul.mlir |   2 +-
 .../Dialect/Transform/match_matmul.mlir       |   2 +-
 .../TestTransformDialectInterpreter.cpp       |   8 +-
 12 files changed, 186 insertions(+), 54 deletions(-)
 create mode 100644 mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-dir.mlir
 rename mlir/test/Dialect/Transform/{test-interpreter-external-symbol-def.mlir => test-interpreter-library/definitions-self-contained.mlir} (100%)
 create mode 100644 mlir/test/Dialect/Transform/test-interpreter-library/definitions-with-unresolved.mlir

diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
index a6f0dddebd7eacf..16ef0bc6a739200 100644
--- a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
@@ -33,7 +33,7 @@ namespace detail {
 /// Template-free implementation of TransformInterpreterPassBase::initialize.
 LogicalResult interpreterBaseInitializeImpl(
     MLIRContext *context, StringRef transformFileName,
-    StringRef transformLibraryFileName,
+    ArrayRef<std::string> transformLibraryPaths,
     std::shared_ptr<OwningOpRef<ModuleOp>> &module,
     std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
     function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
@@ -48,7 +48,7 @@ LogicalResult interpreterBaseRunOnOperationImpl(
     const RaggedArray<MappedValue> &extraMappings,
     const TransformOptions &options,
     const Pass::Option<std::string> &transformFileName,
-    const Pass::Option<std::string> &transformLibraryFileName,
+    const Pass::ListOption<std::string> &transformLibraryPaths,
     const Pass::Option<std::string> &debugPayloadRootTag,
     const Pass::Option<std::string> &debugTransformRootTag,
     StringRef binaryName);
@@ -62,11 +62,12 @@ LogicalResult interpreterBaseRunOnOperationImpl(
 ///     transform script. If empty, `debugTransformRootTag` is considered or the
 ///     pass root operation must contain a single top-level transform op that
 ///     will be interpreted.
-///   - transformLibraryFileName: if non-empty, the module in this file will be
+///   - transformLibraryPaths: if non-empty, the modules in these files will be
 ///     merged into the main transform script run by the interpreter before
 ///     execution. This allows to provide definitions for external functions
-///     used in the main script. Other public symbols in the library module may
-///     lead to collisions with public symbols in the main script.
+///     used in the main script. Other public symbols in the library modules may
+///     lead to collisions with public symbols in the main script and among each
+///     other.
 ///   - debugPayloadRootTag: if non-empty, the value of the attribute named
 ///     `kTransformDialectTagAttrName` indicating the single op that is
 ///     considered the payload root of the transform interpreter; otherwise, the
@@ -118,16 +119,26 @@ class TransformInterpreterPassBase : public GeneratedBase<Concrete> {
     REQUIRE_PASS_OPTION(transformFileName);
     REQUIRE_PASS_OPTION(debugPayloadRootTag);
     REQUIRE_PASS_OPTION(debugTransformRootTag);
-    REQUIRE_PASS_OPTION(transformLibraryFileName);
 
 #undef REQUIRE_PASS_OPTION
 
+#define REQUIRE_PASS_LIST_OPTION(NAME)                                         \
+  static_assert(                                                               \
+      std::is_same_v<                                                          \
+          std::remove_reference_t<decltype(std::declval<Concrete &>().NAME)>,  \
+          Pass::ListOption<std::string>>,                                      \
+      "required " #NAME " string pass option is missing")
+
+    REQUIRE_PASS_LIST_OPTION(transformLibraryPaths);
+
+#undef REQUIRE_PASS_LIST_OPTION
+
     StringRef transformFileName =
         static_cast<Concrete *>(this)->transformFileName;
-    StringRef transformLibraryFileName =
-        static_cast<Concrete *>(this)->transformLibraryFileName;
+    ArrayRef<std::string> transformLibraryPaths =
+        static_cast<Concrete *>(this)->transformLibraryPaths;
     return detail::interpreterBaseInitializeImpl(
-        context, transformFileName, transformLibraryFileName,
+        context, transformFileName, transformLibraryPaths,
         sharedTransformModule, transformLibraryModule,
         [this](OpBuilder &builder, Location loc) {
           return static_cast<Concrete *>(this)->constructTransformModule(
@@ -162,7 +173,7 @@ class TransformInterpreterPassBase : public GeneratedBase<Concrete> {
             op, pass->getArgument(), sharedTransformModule,
             transformLibraryModule,
             /*extraMappings=*/{}, options, pass->transformFileName,
-            pass->transformLibraryFileName, pass->debugPayloadRootTag,
+            pass->transformLibraryPaths, pass->debugPayloadRootTag,
             pass->debugTransformRootTag, binaryName)) ||
         failed(pass->runAfterInterpreter(op))) {
       return pass->signalPassFailure();
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index 68a735e7ef8e0b3..765061b2a6532fe 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/IR/Visitors.h"
@@ -194,7 +195,7 @@ saveReproToTempFile(llvm::raw_ostream &os, Operation *target,
                     Operation *transform, StringRef passName,
                     const Pass::Option<std::string> &debugPayloadRootTag,
                     const Pass::Option<std::string> &debugTransformRootTag,
-                    const Pass::Option<std::string> &transformLibraryFileName,
+                    const Pass::ListOption<std::string> &transformLibraryPaths,
                     StringRef binaryName) {
   using llvm::sys::fs::TempFile;
   Operation *root = getRootOperation(target);
@@ -231,7 +232,7 @@ static void performOptionalDebugActions(
     Operation *target, Operation *transform, StringRef passName,
     const Pass::Option<std::string> &debugPayloadRootTag,
     const Pass::Option<std::string> &debugTransformRootTag,
-    const Pass::Option<std::string> &transformLibraryFileName,
+    const Pass::ListOption<std::string> &transformLibraryPaths,
     StringRef binaryName) {
   MLIRContext *context = target->getContext();
 
@@ -284,7 +285,7 @@ static void performOptionalDebugActions(
   DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_FILE, {
     saveReproToTempFile(llvm::dbgs(), target, transform, passName,
                         debugPayloadRootTag, debugTransformRootTag,
-                        transformLibraryFileName, binaryName);
+                        transformLibraryPaths, binaryName);
   });
 
   // Remove temporary attributes if they were set.
@@ -534,7 +535,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
     const RaggedArray<MappedValue> &extraMappings,
     const TransformOptions &options,
     const Pass::Option<std::string> &transformFileName,
-    const Pass::Option<std::string> &transformLibraryFileName,
+    const Pass::ListOption<std::string> &transformLibraryPaths,
     const Pass::Option<std::string> &debugPayloadRootTag,
     const Pass::Option<std::string> &debugTransformRootTag,
     StringRef binaryName) {
@@ -597,7 +598,8 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
     if (failed(
             mergeSymbolsInto(SymbolTable::getNearestSymbolTable(transformRoot),
                              transformLibraryModule->get()->clone())))
-      return failure();
+      return emitError(transformRoot->getLoc(),
+                       "failed to merge library symbols into transform root");
   }
 
   // Step 4
@@ -606,7 +608,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
   // repro to stderr and/or a file.
   performOptionalDebugActions(target, transformRoot, passName,
                               debugPayloadRootTag, debugTransformRootTag,
-                              transformLibraryFileName, binaryName);
+                              transformLibraryPaths, binaryName);
 
   // Step 5
   // ------
@@ -617,53 +619,131 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
 
 LogicalResult transform::detail::interpreterBaseInitializeImpl(
     MLIRContext *context, StringRef transformFileName,
-    StringRef transformLibraryFileName,
+    ArrayRef<std::string> transformLibraryPaths,
     std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
     std::shared_ptr<OwningOpRef<ModuleOp>> &transformLibraryModule,
     function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
         moduleBuilder) {
-  OwningOpRef<ModuleOp> parsedTransformModule;
-  if (failed(parseTransformModuleFromFile(context, transformFileName,
-                                          parsedTransformModule)))
-    return failure();
-  if (parsedTransformModule && failed(mlir::verify(*parsedTransformModule)))
-    return failure();
+  auto unknownLoc = UnknownLoc::get(context);
 
-  OwningOpRef<ModuleOp> parsedLibraryModule;
-  if (failed(parseTransformModuleFromFile(context, transformLibraryFileName,
-                                          parsedLibraryModule)))
-    return failure();
-  if (parsedLibraryModule && failed(mlir::verify(*parsedLibraryModule)))
-    return failure();
+  // Parse module from file.
+  OwningOpRef<ModuleOp> moduleFromFile;
+  {
+    auto loc = FileLineColLoc::get(context, transformFileName, 0, 0);
+    if (failed(parseTransformModuleFromFile(context, transformFileName,
+                                            moduleFromFile)))
+      return emitError(loc) << "failed to parse transform module";
+    if (moduleFromFile && failed(mlir::verify(*moduleFromFile)))
+      return emitError(loc) << "failed to verify transform module";
+  }
 
-  if (parsedTransformModule) {
-    sharedTransformModule = std::make_shared<OwningOpRef<ModuleOp>>(
-        std::move(parsedTransformModule));
+  // Assemble list of library files.
+  SmallVector<std::string> libraryFileNames;
+  for (const std::string &path : transformLibraryPaths) {
+    auto loc = FileLineColLoc::get(context, transformFileName, 0, 0);
+
+    if (llvm::sys::fs::is_regular_file(path)) {
+      LLVM_DEBUG(DBGS() << "Adding '" << path << "' to list of files\n");
+      libraryFileNames.push_back(path);
+      continue;
+    }
+
+    if (!llvm::sys::fs::is_directory(path)) {
+      return emitError(loc)
+             << "'" << path << "' is neither a file nor a directory";
+    }
+
+    LLVM_DEBUG(DBGS() << "Looking for files in '" << path << "':\n");
+
+    std::error_code ec;
+    for (llvm::sys::fs::directory_iterator it(path, ec), itEnd;
+         it != itEnd && !ec; it.increment(ec)) {
+      const std::string &fileName = it->path();
+
+      if (it->type() != llvm::sys::fs::file_type::regular_file) {
+        LLVM_DEBUG(DBGS() << "  Skipping non-regular file '" << fileName
+                          << "'\n");
+        continue;
+      }
+
+      if (!StringRef(fileName).endswith(".mlir")) {
+        LLVM_DEBUG(DBGS() << "  Skipping '" << fileName
+                          << "' because it does not end with '.mlir'\n");
+        continue;
+      }
+
+      LLVM_DEBUG(DBGS() << "  Adding '" << fileName << "' to list of files\n");
+      libraryFileNames.push_back(fileName);
+    }
+
+    if (ec)
+      return emitError(loc) << "error while opening files in '" << path
+                            << "': " << ec.message();
+  }
+
+  // Parse modules from library files.
+  SmallVector<OwningOpRef<ModuleOp>> parsedLibraries;
+  for (const std::string &libraryFileName : libraryFileNames) {
+    OwningOpRef<ModuleOp> parsedLibrary;
+    auto loc = FileLineColLoc::get(context, libraryFileName, 0, 0);
+    if (failed(parseTransformModuleFromFile(context, libraryFileName,
+                                            parsedLibrary)))
+      return emitError(loc) << "failed to parse transform library module";
+    if (parsedLibrary && failed(mlir::verify(*parsedLibrary)))
+      return emitError(loc) << "failed to verify transform library module";
+    parsedLibraries.push_back(std::move(parsedLibrary));
+  }
+
+  // Build shared transform module.
+  if (moduleFromFile) {
+    sharedTransformModule =
+        std::make_shared<OwningOpRef<ModuleOp>>(std::move(moduleFromFile));
   } else if (moduleBuilder) {
-    // TODO: better location story.
-    auto location = UnknownLoc::get(context);
+    auto loc = FileLineColLoc::get(context, "<shared-transform-module>", 0, 0);
     auto localModule = std::make_shared<OwningOpRef<ModuleOp>>(
-        ModuleOp::create(location, "__transform"));
+        ModuleOp::create(unknownLoc, "__transform"));
 
     OpBuilder b(context);
     b.setInsertionPointToEnd(localModule->get().getBody());
-    if (std::optional<LogicalResult> result = moduleBuilder(b, location)) {
+    if (std::optional<LogicalResult> result = moduleBuilder(b, loc)) {
       if (failed(*result))
-        return failure();
+        return (*localModule)->emitError()
+               << "failed to create shared transform module";
       sharedTransformModule = std::move(localModule);
     }
   }
 
-  if (!parsedLibraryModule || !*parsedLibraryModule)
+  if (parsedLibraries.empty())
     return success();
 
+  // Merge parsed libraries into one module.
+  auto loc = FileLineColLoc::get(context, "<shared-library-module>", 0, 0);
+  OwningOpRef<ModuleOp> mergedParsedLibraries =
+      ModuleOp::create(loc, "__transform");
+  {
+    mergedParsedLibraries.get()->setAttr("transform.with_named_sequence",
+                                         UnitAttr::get(context));
+    IRRewriter rewriter(context);
+    // TODO: extend `mergeSymbolsInto` to support multiple `other` modules.
+    for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) {
+      if (failed(mergeSymbolsInto(mergedParsedLibraries.get(),
+                                  std::move(parsedLibrary))))
+        return mergedParsedLibraries->emitError()
+               << "failed to verify merged transform module";
+    }
+  }
+
+  // Use parsed libraries to resolve symbols in shared transform module or
+  // return as separate library module.
   if (sharedTransformModule && *sharedTransformModule) {
     if (failed(mergeSymbolsInto(sharedTransformModule->get(),
-                                std::move(parsedLibraryModule))))
-      return failure();
+                                std::move(mergedParsedLibraries))))
+      return (*sharedTransformModule)->emitError()
+             << "failed to merge symbols from library files "
+                "into shared transform module";
   } else {
-    transformLibraryModule =
-        std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsedLibraryModule));
+    transformLibraryModule = std::make_shared<OwningOpRef<ModuleOp>>(
+        std::move(mergedParsedLibraries));
   }
   return success();
 }
diff --git a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir
index c219cfe08ac4d6a..c25212fbe98782b 100644
--- a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir
@@ -2,7 +2,7 @@
 
 // RUN: mlir-opt %s -test-lower-to-llvm -cse | FileCheck %s
 
-// RUN: mlir-opt %s -test-transform-dialect-interpreter="transform-library-file-name=%p/lower-to-llvm-transform-symbol-def.mlir debug-payload-root-tag=payload" \
+// RUN: mlir-opt %s -test-transform-dialect-interpreter="transform-library-paths=%p/lower-to-llvm-transform-symbol-def.mlir debug-payload-root-tag=payload" \
 // RUN:   -test-transform-dialect-erase-schedule -cse \
 // RUN: | FileCheck %s
 
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
index dd8d141e994da0e..2c4812bf32b0f03 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
@@ -1,7 +1,7 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-paths=%p/test-interpreter-library/definitions-self-contained.mlir})" \
 // RUN:             --verify-diagnostics
 
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-paths=%p/test-interpreter-library/definitions-self-contained.mlir}, test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-paths=%p/test-interpreter-library/definitions-self-contained.mlir})" \
 // RUN:             --verify-diagnostics
 
 // The external transform script has a declaration to the named sequence @foo,
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-dir.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-dir.mlir
new file mode 100644
index 000000000000000..8b8254976e9aeec
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-dir.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-paths=%p%{fs-sep}test-interpreter-library})" \
+// RUN:             --verify-diagnostics --split-input-file | FileCheck %s
+
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-paths=%p%{fs-sep}test-interpreter-library/definitions-self-contained.mlir,%p%{fs-sep}test-interpreter-library/definitions-with-unresolved.mlir})" \
+// RUN:             --verify-diagnostics --split-input-file | FileCheck %s
+
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-paths=%p%{fs-sep}test-interpreter-library}, test-transform-dialect-interpreter)" \
+// RUN:             --verify-diagnostics --split-input-file | FileCheck %s
+
+// The definition of the @foo named sequence is provided in another file. It
+// will be included because of the pass option. Repeated application of the
+// same pass, with or without the library option, should not be a problem.
+// Note that the same diagnostic produced twice at the same location only
+// needs to be matched once.
+
+// expected-remark @below {{message}}
+module attributes {transform.with_named_sequence} {
+  // CHECK: transform.named_sequence @print_message
+  transform.named_sequence @print_message(%arg0: !transform.any_op {transform.readonly})
+
+  transform.named_sequence @reference_other_module(!transform.any_op {transform.readonly})
+
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    include @print_message failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    include @reference_other_module failures(propagate) (%arg0) : (!transform.any_op) -> ()
+  }
+}
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
index 7452deb39b6c18d..c1bd071dc138d56 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def-invalid.mlir}, test-transform-dialect-interpreter)" \
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-paths=%p/test-interpreter-external-symbol-def-invalid.mlir}, test-transform-dialect-interpreter)" \
 // RUN:             --verify-diagnostics --split-input-file
 
 // The definition of the @print_message named sequence is provided in another file. It
@@ -8,6 +8,7 @@ module attributes {transform.with_named_sequence} {
   // expected-error @below {{external definition has a mismatching signature}}
   transform.named_sequence private @print_message(!transform.op<"builtin.module"> {transform.readonly})
 
+  // expected-error @below {{failed to merge library symbols into transform root}}
   transform.sequence failures(propagate) {
   ^bb0(%arg0: !transform.op<"builtin.module">):
     include @print_message failures(propagate) (%arg0) : (!transform.op<"builtin.module">) -> ()
@@ -32,6 +33,7 @@ module attributes {transform.with_named_sequence} {
   // expected-error @below {{external definition has mismatching consumption annotations for argument #0}}
   transform.named_sequence private @consuming(%arg0: !transform.any_op {transform.readonly})
 
+  // expected-error @below {{failed to merge library symbols into transform root}}
   transform.sequence failures(suppress) {
   ^bb0(%arg0: !transform.any_op):
     include @consuming failures(suppress) (%arg0) : (!transform.any_op) -> ()
@@ -47,6 +49,7 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 
+  // expected-error @below {{failed to merge library symbols into transform root}}
   transform.sequence failures(suppress) {
   ^bb0(%arg0: !transform.any_op):
     include @print_message failures(propagate) (%arg0) : (!transform.any_op) -> ()
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
index 7d0837abebde32c..339e62072cd5510 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
@@ -1,7 +1,7 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-paths=%p/test-interpreter-library/definitions-self-contained.mlir})" \
 // RUN:             --verify-diagnostics --split-input-file | FileCheck %s
 
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-paths=%p/test-interpreter-library/definitions-self-contained.mlir}, test-transform-dialect-interpreter)" \
 // RUN:             --verify-diagnostics --split-input-file | FileCheck %s
 
 // The definition of the @print_message named sequence is provided in another
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir b/mlir/test/Dialect/Transform/test-interpreter-library/definitions-self-contained.mlir
similarity index 100%
rename from mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir
rename to mlir/test/Dialect/Transform/test-interpreter-library/definitions-self-contained.mlir
diff --git a/mlir/test/Dialect/Transform/test-interpreter-library/definitions-with-unresolved.mlir b/mlir/test/Dialect/Transform/test-interpreter-library/definitions-with-unresolved.mlir
new file mode 100644
index 000000000000000..b3d076f4698495f
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-library/definitions-with-unresolved.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-opt %s
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @print_message(%arg0: !transform.any_op {transform.readonly})
+
+  transform.named_sequence @reference_other_module(%arg0: !transform.any_op) {
+    transform.include @print_message failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    transform.yield
+  }
+}
diff --git a/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir b/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir
index 73bc243ad76060d..3914dc6198a69c2 100644
--- a/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir
+++ b/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-file-name=%p/match_matmul_common.mlir' --verify-diagnostics
+// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-paths=%p/match_matmul_common.mlir' --verify-diagnostics
 
 module attributes { transform.with_named_sequence } {
   transform.named_sequence @_match_matmul_like(
diff --git a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
index f164a3d1bd99dd0..bb0f1125fd39716 100644
--- a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
+++ b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-file-name=%p/match_matmul_common.mlir' --verify-diagnostics
+// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-paths=%p/match_matmul_common.mlir' --verify-diagnostics
 
 module attributes { transform.with_named_sequence } {
   transform.named_sequence @_match_matmul_like(
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
index 578d9abe4a56ecb..c60b21c918338b4 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
@@ -161,7 +161,7 @@ class TestTransformDialectInterpreterPass
     if (failed(transform::detail::interpreterBaseRunOnOperationImpl(
             getOperation(), getArgument(), getSharedTransformModule(),
             getTransformLibraryModule(), extraMapping, options,
-            transformFileName, transformLibraryFileName, debugPayloadRootTag,
+            transformFileName, transformLibraryPaths, debugPayloadRootTag,
             debugTransformRootTag, getBinaryName())))
       return signalPassFailure();
   }
@@ -216,9 +216,9 @@ class TestTransformDialectInterpreterPass
           "the given value as container IR for top-level transform ops. This "
           "allows user control on what transformation to apply. If empty, "
           "select the container of the top-level transform op.")};
-  Option<std::string> transformLibraryFileName{
-      *this, "transform-library-file-name", llvm::cl::init(""),
-      llvm::cl::desc("Optional name of a file with a module that should be "
+  ListOption<std::string> transformLibraryPaths{
+      *this, "transform-library-paths", llvm::cl::ZeroOrMore,
+      llvm::cl::desc("Optional paths to files with modules that should be "
                      "merged into the transform module to provide the "
                      "definitions of external named sequences.")};
 

>From 720898c581db279463497a19c5591e2a6e4c0b4b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Fri, 6 Oct 2023 07:53:43 +0000
Subject: [PATCH 2/2] Address comments from @ftynse's review.

---
 .../TransformInterpreterPassBase.cpp          | 67 ++++++++++++-------
 1 file changed, 41 insertions(+), 26 deletions(-)

diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index 765061b2a6532fe..764d7e25854206e 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -617,34 +617,20 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
                          extraMappings, options);
 }
 
-LogicalResult transform::detail::interpreterBaseInitializeImpl(
-    MLIRContext *context, StringRef transformFileName,
-    ArrayRef<std::string> transformLibraryPaths,
-    std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
-    std::shared_ptr<OwningOpRef<ModuleOp>> &transformLibraryModule,
-    function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
-        moduleBuilder) {
-  auto unknownLoc = UnknownLoc::get(context);
-
-  // Parse module from file.
-  OwningOpRef<ModuleOp> moduleFromFile;
-  {
-    auto loc = FileLineColLoc::get(context, transformFileName, 0, 0);
-    if (failed(parseTransformModuleFromFile(context, transformFileName,
-                                            moduleFromFile)))
-      return emitError(loc) << "failed to parse transform module";
-    if (moduleFromFile && failed(mlir::verify(*moduleFromFile)))
-      return emitError(loc) << "failed to verify transform module";
-  }
-
-  // Assemble list of library files.
-  SmallVector<std::string> libraryFileNames;
-  for (const std::string &path : transformLibraryPaths) {
-    auto loc = FileLineColLoc::get(context, transformFileName, 0, 0);
+/// Expands the given list of `paths` to a list of `.mlir` files.
+///
+/// Each entry in `paths` may either be a regular file, in which case it ends up
+/// in the result list, or a directory, in which case all (regular) `.mlir`
+/// files in that directory are added. Any other file types lead to a failure.
+static LogicalResult
+expandPathsToMLIRFiles(ArrayRef<std::string> &paths, MLIRContext *const context,
+                       SmallVectorImpl<std::string> &fileNames) {
+  for (const std::string &path : paths) {
+    auto loc = FileLineColLoc::get(context, path, 0, 0);
 
     if (llvm::sys::fs::is_regular_file(path)) {
       LLVM_DEBUG(DBGS() << "Adding '" << path << "' to list of files\n");
-      libraryFileNames.push_back(path);
+      fileNames.push_back(path);
       continue;
     }
 
@@ -673,7 +659,7 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
       }
 
       LLVM_DEBUG(DBGS() << "  Adding '" << fileName << "' to list of files\n");
-      libraryFileNames.push_back(fileName);
+      fileNames.push_back(fileName);
     }
 
     if (ec)
@@ -681,6 +667,35 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
                             << "': " << ec.message();
   }
 
+  return success();
+}
+
+LogicalResult transform::detail::interpreterBaseInitializeImpl(
+    MLIRContext *context, StringRef transformFileName,
+    ArrayRef<std::string> transformLibraryPaths,
+    std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
+    std::shared_ptr<OwningOpRef<ModuleOp>> &transformLibraryModule,
+    function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
+        moduleBuilder) {
+  auto unknownLoc = UnknownLoc::get(context);
+
+  // Parse module from file.
+  OwningOpRef<ModuleOp> moduleFromFile;
+  {
+    auto loc = FileLineColLoc::get(context, transformFileName, 0, 0);
+    if (failed(parseTransformModuleFromFile(context, transformFileName,
+                                            moduleFromFile)))
+      return emitError(loc) << "failed to parse transform module";
+    if (moduleFromFile && failed(mlir::verify(*moduleFromFile)))
+      return emitError(loc) << "failed to verify transform module";
+  }
+
+  // Assemble list of library files.
+  SmallVector<std::string> libraryFileNames;
+  if (failed(expandPathsToMLIRFiles(transformLibraryPaths, context,
+                                    libraryFileNames)))
+    return failure();
+
   // Parse modules from library files.
   SmallVector<OwningOpRef<ModuleOp>> parsedLibraries;
   for (const std::string &libraryFileName : libraryFileNames) {



More information about the Mlir-commits mailing list