[Mlir-commits] [mlir] [mlir][transform] Allow passing various library files to interpreter. (PR #67120)

Ingo Müller llvmlistbot at llvm.org
Fri Sep 22 06:04:01 PDT 2023


https://github.com/ingomueller-net updated https://github.com/llvm/llvm-project/pull/67120

>From 770a245972e57ad08ef62605df6f98aad9aab267 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Fri, 22 Sep 2023 12:05:38 +0000
Subject: [PATCH] [mlir][transform] Allow passing various library files to
 interpreter.

The transform interpreter accepts an argument to a "library" file with
named sequences. This patch extends this functionality such that (1)
several such individual files are accepted and (2) folders can be passed
in, in which all `*.mlir` files are loaded.
---
 .../Transforms/TransformInterpreterPassBase.h |  43 ++--
 .../TransformInterpreterPassBase.cpp          | 186 +++++++++++++-----
 .../Dialect/Transform/match_batch_matmul.mlir |   2 +-
 .../Dialect/Transform/match_matmul.mlir       |   2 +-
 .../TestTransformDialectInterpreter.cpp       |  17 +-
 5 files changed, 186 insertions(+), 64 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
index 91903e254b0d5b3..5d021da90e9a594 100644
--- a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
@@ -33,7 +33,8 @@ namespace detail {
 /// Template-free implementation of TransformInterpreterPassBase::initialize.
 LogicalResult interpreterBaseInitializeImpl(
     MLIRContext *context, StringRef transformFileName,
-    StringRef transformLibraryFileName,
+    ArrayRef<std::string> transformLibraryFileNames,
+    ArrayRef<std::string> transformLibraryDirNames,
     std::shared_ptr<OwningOpRef<ModuleOp>> &module,
     std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
     function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
@@ -48,7 +49,8 @@ LogicalResult interpreterBaseRunOnOperationImpl(
     const RaggedArray<MappedValue> &extraMappings,
     const TransformOptions &options,
     const Pass::Option<std::string> &transformFileName,
-    const Pass::Option<std::string> &transformLibraryFileName,
+    const Pass::ListOption<std::string> &transformLibraryFileNames,
+    const Pass::ListOption<std::string> &transformLibraryDirNames,
     const Pass::Option<std::string> &debugPayloadRootTag,
     const Pass::Option<std::string> &debugTransformRootTag,
     StringRef binaryName);
@@ -62,9 +64,14 @@ LogicalResult interpreterBaseRunOnOperationImpl(
 ///     transform script. If empty, `debugTransformRootTag` is considered or the
 ///     pass root operation must contain a single top-level transform op that
 ///     will be interpreted.
-///   - transformLibraryFileName: if non-empty, the name of the file containing
+///   - transformLibraryFileNames: if non-empty, the names of files containing
 ///     definitions of external symbols referenced in the transform script.
-///     These definitions will be used to replace declarations.
+///     These definitions will be used to replace declarations and must be
+///     unique within all files provided by this and the next option.
+///   - transformLibraryDirNames: if non-empty, the names of directories
+///     containing definitions of external symbols referenced in the transform
+///     script. These definitions will be used to replace declarations and must
+///     be unique within all files provided by this and the previous option.
 ///   - debugPayloadRootTag: if non-empty, the value of the attribute named
 ///     `kTransformDialectTagAttrName` indicating the single op that is
 ///     considered the payload root of the transform interpreter; otherwise, the
@@ -115,17 +122,30 @@ class TransformInterpreterPassBase : public GeneratedBase<Concrete> {
     REQUIRE_PASS_OPTION(transformFileName);
     REQUIRE_PASS_OPTION(debugPayloadRootTag);
     REQUIRE_PASS_OPTION(debugTransformRootTag);
-    REQUIRE_PASS_OPTION(transformLibraryFileName);
 
 #undef REQUIRE_PASS_OPTION
 
+#define REQUIRE_PASS_LIST_OPTION(NAME)                                         \
+  static_assert(                                                               \
+      std::is_same_v<                                                          \
+          std::remove_reference_t<decltype(std::declval<Concrete &>().NAME)>,  \
+          Pass::ListOption<std::string>>,                                      \
+      "required " #NAME " string pass option is missing")
+
+    REQUIRE_PASS_LIST_OPTION(transformLibraryFileNames);
+    REQUIRE_PASS_LIST_OPTION(transformLibraryDirNames);
+
+#undef REQUIRE_PASS_LIST_OPTION
+
     StringRef transformFileName =
         static_cast<Concrete *>(this)->transformFileName;
-    StringRef transformLibraryFileName =
-        static_cast<Concrete *>(this)->transformLibraryFileName;
+    ArrayRef<std::string> transformLibraryFileNames =
+        static_cast<Concrete *>(this)->transformLibraryFileNames;
+    ArrayRef<std::string> transformLibraryDirNames =
+        static_cast<Concrete *>(this)->transformLibraryDirNames;
     return detail::interpreterBaseInitializeImpl(
-        context, transformFileName, transformLibraryFileName,
-        sharedTransformModule, transformLibraryModule,
+        context, transformFileName, transformLibraryFileNames,
+        transformLibraryDirNames, sharedTransformModule, transformLibraryModule,
         [this](OpBuilder &builder, Location loc) {
           return static_cast<Concrete *>(this)->constructTransformModule(
               builder, loc);
@@ -159,8 +179,9 @@ class TransformInterpreterPassBase : public GeneratedBase<Concrete> {
             op, pass->getArgument(), sharedTransformModule,
             transformLibraryModule,
             /*extraMappings=*/{}, options, pass->transformFileName,
-            pass->transformLibraryFileName, pass->debugPayloadRootTag,
-            pass->debugTransformRootTag, binaryName)) ||
+            pass->transformLibraryFileNames, pass->transformLibraryDirNames,
+            pass->debugPayloadRootTag, pass->debugTransformRootTag,
+            binaryName)) ||
         failed(pass->runAfterInterpreter(op))) {
       return pass->signalPassFailure();
     }
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index d5c65b23e3a2134..4beb8f31514a1c9 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -161,15 +161,24 @@ static llvm::raw_ostream &
 printReproCall(llvm::raw_ostream &os, StringRef rootOpName, StringRef passName,
                const Pass::Option<std::string> &debugPayloadRootTag,
                const Pass::Option<std::string> &debugTransformRootTag,
-               const Pass::Option<std::string> &transformLibraryFileName,
+               const Pass::ListOption<std::string> &transformLibraryFileNames,
+               const Pass::ListOption<std::string> &transformLibraryDirNames,
                StringRef binaryName) {
-  std::string transformLibraryOption = "";
-  if (!transformLibraryFileName.empty()) {
-    transformLibraryOption =
-        llvm::formatv(" {0}={1}", transformLibraryFileName.getArgStr(),
-                      transformLibraryFileName.getValue())
-            .str();
+  std::string transformLibraryOptions = "";
+  {
+    llvm::raw_string_ostream optionStream(transformLibraryOptions);
+    if (!transformLibraryFileNames.empty()) {
+      optionStream << " " << transformLibraryFileNames.getArgStr() << "='";
+      llvm::interleave(transformLibraryFileNames, optionStream, ",");
+      optionStream << "'";
+    }
+    if (!transformLibraryDirNames.empty()) {
+      optionStream << " " << transformLibraryDirNames.getArgStr() << "='";
+      llvm::interleave(transformLibraryDirNames, optionStream, ",");
+      optionStream << "'";
+    }
   }
+
   os << llvm::formatv(
       "{7} --pass-pipeline=\"{0}({1}{{{2}={3} {4}={5}{6}})\"", rootOpName,
       passName, debugPayloadRootTag.getArgStr(),
@@ -180,7 +189,7 @@ printReproCall(llvm::raw_ostream &os, StringRef rootOpName, StringRef passName,
       debugTransformRootTag.empty()
           ? StringRef(kTransformDialectTagTransformContainerValue)
           : debugTransformRootTag,
-      transformLibraryOption, binaryName);
+      transformLibraryOptions, binaryName);
   return os;
 }
 
@@ -200,7 +209,8 @@ void saveReproToTempFile(
     llvm::raw_ostream &os, Operation *target, Operation *transform,
     StringRef passName, const Pass::Option<std::string> &debugPayloadRootTag,
     const Pass::Option<std::string> &debugTransformRootTag,
-    const Pass::Option<std::string> &transformLibraryFileName,
+    const Pass::ListOption<std::string> &transformLibraryFileNames,
+    const Pass::ListOption<std::string> &transformLibraryDirNames,
     StringRef binaryName) {
   using llvm::sys::fs::TempFile;
   Operation *root = getRootOperation(target);
@@ -227,7 +237,8 @@ void saveReproToTempFile(
   os << "=== Transform Interpreter Repro ===\n";
   printReproCall(os, root->getName().getStringRef(), passName,
                  debugPayloadRootTag, debugTransformRootTag,
-                 transformLibraryFileName, binaryName)
+                 transformLibraryFileNames, transformLibraryDirNames,
+                 binaryName)
       << " " << filename << "\n";
   os << "===================================\n";
 }
@@ -238,7 +249,8 @@ static void performOptionalDebugActions(
     Operation *target, Operation *transform, StringRef passName,
     const Pass::Option<std::string> &debugPayloadRootTag,
     const Pass::Option<std::string> &debugTransformRootTag,
-    const Pass::Option<std::string> &transformLibraryFileName,
+    const Pass::ListOption<std::string> &transformLibraryFileNames,
+    const Pass::ListOption<std::string> &transformLibraryDirNames,
     StringRef binaryName) {
   MLIRContext *context = target->getContext();
 
@@ -279,10 +291,10 @@ static void performOptionalDebugActions(
 
   DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_STDERR, {
     llvm::dbgs() << "=== Transform Interpreter Repro ===\n";
-    printReproCall(llvm::dbgs() << "cat <<EOF | ",
-                   root->getName().getStringRef(), passName,
-                   debugPayloadRootTag, debugTransformRootTag,
-                   transformLibraryFileName, binaryName)
+    printReproCall(
+        llvm::dbgs() << "cat <<EOF | ", root->getName().getStringRef(),
+        passName, debugPayloadRootTag, debugTransformRootTag,
+        transformLibraryFileNames, transformLibraryDirNames, binaryName)
         << "\n";
     printModuleForRepro(llvm::dbgs(), root, transform);
     llvm::dbgs() << "\nEOF\n";
@@ -292,7 +304,8 @@ static void performOptionalDebugActions(
   DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_FILE, {
     saveReproToTempFile(llvm::dbgs(), target, transform, passName,
                         debugPayloadRootTag, debugTransformRootTag,
-                        transformLibraryFileName, binaryName);
+                        transformLibraryFileNames, transformLibraryDirNames,
+                        binaryName);
   });
 
   // Remove temporary attributes if they were set.
@@ -383,7 +396,8 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
     const RaggedArray<MappedValue> &extraMappings,
     const TransformOptions &options,
     const Pass::Option<std::string> &transformFileName,
-    const Pass::Option<std::string> &transformLibraryFileName,
+    const Pass::ListOption<std::string> &transformLibraryFileNames,
+    const Pass::ListOption<std::string> &transformLibraryDirNames,
     const Pass::Option<std::string> &debugPayloadRootTag,
     const Pass::Option<std::string> &debugTransformRootTag,
     StringRef binaryName) {
@@ -449,7 +463,8 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
   // repro to stderr and/or a file.
   performOptionalDebugActions(target, transformRoot, passName,
                               debugPayloadRootTag, debugTransformRootTag,
-                              transformLibraryFileName, binaryName);
+                              transformLibraryFileNames,
+                              transformLibraryDirNames, binaryName);
 
   // Step 5
   // ------
@@ -460,51 +475,132 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
 
 LogicalResult transform::detail::interpreterBaseInitializeImpl(
     MLIRContext *context, StringRef transformFileName,
-    StringRef transformLibraryFileName,
-    std::shared_ptr<OwningOpRef<ModuleOp>> &module,
-    std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
+    ArrayRef<std::string> transformLibraryFileNames,
+    ArrayRef<std::string> transformLibraryDirNames,
+    std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
+    std::shared_ptr<OwningOpRef<ModuleOp>> &transformLibraryModule,
     function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
         moduleBuilder) {
-  OwningOpRef<ModuleOp> parsed;
-  if (failed(parseTransformModuleFromFile(context, transformFileName, parsed)))
-    return failure();
-  if (parsed && failed(mlir::verify(*parsed)))
-    return failure();
+  // Parse module from file.
+  OwningOpRef<ModuleOp> moduleFromFile;
+  {
+    auto loc = FileLineColLoc::get(context, transformFileName, 0, 0);
+    if (failed(parseTransformModuleFromFile(context, transformFileName,
+                                            moduleFromFile))) {
+      emitError(loc, "failed to parse transform module");
+      return failure();
+    }
+    if (moduleFromFile && failed(mlir::verify(*moduleFromFile))) {
+      emitError(loc, "failed to verify transform module");
+      return failure();
+    }
+  }
 
-  OwningOpRef<ModuleOp> parsedLibrary;
-  if (failed(parseTransformModuleFromFile(context, transformLibraryFileName,
-                                          parsedLibrary)))
-    return failure();
-  if (parsedLibrary && failed(mlir::verify(*parsedLibrary)))
-    return failure();
+  // Assemble list of library files.
+  SmallVector<std::string> libraryFileNames;
+  libraryFileNames.append(transformLibraryFileNames.begin(),
+                          transformLibraryFileNames.end());
+
+  for (const std::string &dirName : transformLibraryDirNames) {
+    DBGS() << "Opening files in '" << dirName << "':\n";
 
-  if (parsed) {
-    module = std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsed));
+    std::error_code ec;
+    for (llvm::sys::fs::directory_iterator it(dirName, ec), itEnd;
+         it != itEnd && !ec; it.increment(ec)) {
+      const std::string &fileName = it->path();
+
+      if (it->type() != llvm::sys::fs::file_type::regular_file) {
+        DBGS() << "  Skipping non-regular file '" << fileName << "'\n";
+        continue;
+      }
+
+      if (!StringRef(fileName).endswith(".mlir")) {
+        DBGS() << "  Skipping '" << fileName
+               << "' because it does not end with '.mlir'\n";
+        continue;
+      }
+
+      DBGS() << "  Adding '" << fileName << "' to list of files\n";
+      libraryFileNames.push_back(fileName);
+    }
+
+    if (ec) {
+      auto loc = UnknownLoc::get(context);
+      emitError(loc, Twine("error while opening files in '") + dirName +
+                         "': " + ec.message());
+      return failure();
+    }
+  }
+
+  // Parse modules from library files.
+  SmallVector<OwningOpRef<ModuleOp>> parsedLibraries;
+  for (const std::string &libraryFileName : libraryFileNames) {
+    OwningOpRef<ModuleOp> parsedLibrary;
+    auto loc = FileLineColLoc::get(context, libraryFileName, 0, 0);
+    if (failed(parseTransformModuleFromFile(context, libraryFileName,
+                                            parsedLibrary))) {
+      emitError(loc, "failed to parse transform library module");
+      return failure();
+    }
+    if (parsedLibrary && failed(mlir::verify(*parsedLibrary))) {
+      emitError(loc, "failed to verify transform library module");
+      return failure();
+    }
+    parsedLibraries.push_back(std::move(parsedLibrary));
+  }
+
+  // Build shared transform module.
+  if (moduleFromFile) {
+    sharedTransformModule =
+        std::make_shared<OwningOpRef<ModuleOp>>(std::move(moduleFromFile));
   } else if (moduleBuilder) {
     // TODO: better location story.
-    auto location = UnknownLoc::get(context);
+    auto loc = UnknownLoc::get(context);
     auto localModule = std::make_shared<OwningOpRef<ModuleOp>>(
-        ModuleOp::create(location, "__transform"));
+        ModuleOp::create(loc, "__transform"));
 
     OpBuilder b(context);
     b.setInsertionPointToEnd(localModule->get().getBody());
-    if (std::optional<LogicalResult> result = moduleBuilder(b, location)) {
-      if (failed(*result))
+    if (std::optional<LogicalResult> result = moduleBuilder(b, loc)) {
+      if (failed(*result)) {
+        emitError(loc, "failed to create shared transform module");
         return failure();
-      module = std::move(localModule);
+      }
+      sharedTransformModule = std::move(localModule);
     }
   }
 
-  if (!parsedLibrary || !*parsedLibrary)
+  if (parsedLibraries.empty())
     return success();
 
-  if (module && *module) {
-    if (failed(defineDeclaredSymbols(*module->get().getBody(),
-                                     parsedLibrary.get())))
+  // Merge parsed libraries into one module.
+  // TODO: better location story.
+  auto loc = UnknownLoc::get(context);
+  OwningOpRef<ModuleOp> mergedParsedLibraries =
+      ModuleOp::create(loc, "__transform");
+  mergedParsedLibraries.get()->setAttr("transform.with_named_sequence",
+                                       UnitAttr::get(context));
+
+  IRRewriter rewriter(context);
+  rewriter.setInsertionPointToEnd(mergedParsedLibraries->getBody());
+  for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) {
+    rewriter.inlineBlockBefore(parsedLibrary->getBody(),
+                               mergedParsedLibraries->getBody(),
+                               mergedParsedLibraries->getBody()->end());
+    if (failed(mlir::verify(*mergedParsedLibraries))) {
+      emitError(loc, "failed to verify merged transform module");
+      return failure();
+    }
+  }
+
+  // Merge parsed libraries into shared module or return as library module.
+  if (sharedTransformModule && *sharedTransformModule) {
+    if (failed(defineDeclaredSymbols(*sharedTransformModule->get().getBody(),
+                                     mergedParsedLibraries.get())))
       return failure();
   } else {
-    libraryModule =
-        std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsedLibrary));
+    transformLibraryModule = std::make_shared<OwningOpRef<ModuleOp>>(
+        std::move(mergedParsedLibraries));
   }
   return success();
 }
diff --git a/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir b/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir
index 73bc243ad76060d..c7c055c7567e6f3 100644
--- a/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir
+++ b/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-file-name=%p/match_matmul_common.mlir' --verify-diagnostics
+// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-file-names=%p/match_matmul_common.mlir' --verify-diagnostics
 
 module attributes { transform.with_named_sequence } {
   transform.named_sequence @_match_matmul_like(
diff --git a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
index f164a3d1bd99dd0..3e06f2ca0895841 100644
--- a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
+++ b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-file-name=%p/match_matmul_common.mlir' --verify-diagnostics
+// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-file-names=%p/match_matmul_common.mlir' --verify-diagnostics
 
 module attributes { transform.with_named_sequence } {
   transform.named_sequence @_match_matmul_like(
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
index f73deef9d5fd48c..28e5256385efaea 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
@@ -161,7 +161,8 @@ class TestTransformDialectInterpreterPass
     if (failed(transform::detail::interpreterBaseRunOnOperationImpl(
             getOperation(), getArgument(), getSharedTransformModule(),
             getTransformLibraryModule(), extraMapping, options,
-            transformFileName, transformLibraryFileName, debugPayloadRootTag,
+            transformFileName, transformLibraryFileNames,
+            transformLibraryDirNames, debugPayloadRootTag,
             debugTransformRootTag, getBinaryName())))
       return signalPassFailure();
   }
@@ -216,12 +217,16 @@ class TestTransformDialectInterpreterPass
           "the given value as container IR for top-level transform ops. This "
           "allows user control on what transformation to apply. If empty, "
           "select the container of the top-level transform op.")};
-  Option<std::string> transformLibraryFileName{
-      *this, "transform-library-file-name", llvm::cl::init(""),
+  ListOption<std::string> transformLibraryFileNames{
+      *this, "transform-library-file-names", llvm::cl::ZeroOrMore,
+      llvm::cl::desc("Optional filenames containing transform dialect symbol "
+                     "definitions to be injected into the transform module.")};
+  ListOption<std::string> transformLibraryDirNames{
+      *this, "transform-library-dir-names", llvm::cl::ZeroOrMore,
       llvm::cl::desc(
-          "Optional name of the file containing transform dialect symbol "
-          "definitions to be injected into the transform module.")};
-
+          "Optional directories containing transform dialect symbol "
+          "definitions to be injected into the transform module. All '.mlir' "
+          "files rooted under this directory will be loaded.")};
   Option<bool> testModuleGeneration{
       *this, "test-module-generation", llvm::cl::init(false),
       llvm::cl::desc("test the generation of the transform module during pass "



More information about the Mlir-commits mailing list