[Mlir-commits] [mlir] [mlir] WIP: preload transform libraries before the pass (PR #68000)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Mon Oct 2 08:41:04 PDT 2023


https://github.com/ftynse created https://github.com/llvm/llvm-project/pull/68000

Loading transform libraries inside a pass happens at pass initialization time and may be repeated if the pass runs more than once. This patch explores a different mechanism based on extra data owned by the transform dialect. This requires changes to mlir-opt (and equivalent downstream setups), but there is precedent for this with IRDL file loading.

>From 734ff68a1fcb6c877ea19f94aa4ce1c7b420cbc8 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Mon, 2 Oct 2023 15:38:03 +0000
Subject: [PATCH] [mlir] WIP: preload transform libraries before the pass

Loading transform libraries inside a pass happens at pass
initialization time and may be repeated if the pass runs more than
once. This patch explores a different mechanism based on extra data
owned by the transform dialect. This requires changes to mlir-opt (and
equivalent downstream setups), but there is precedent with IRDL.
---
 .../Dialect/Transform/IR/TransformDialect.h   | 24 +++++++++++
 .../include/mlir/Tools/mlir-opt/MlirOptMain.h | 10 +++++
 .../Dialect/Transform/IR/TransformDialect.cpp | 41 +++++++++++++++++++
 .../TransformInterpreterPassBase.cpp          | 33 ++++++++++-----
 mlir/lib/Tools/mlir-opt/CMakeLists.txt        |  1 +
 mlir/lib/Tools/mlir-opt/MlirOptMain.cpp       | 19 +++++++++
 mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir |  3 +-
 .../Dialect/Transform/match_batch_matmul.mlir |  2 +-
 .../Dialect/Transform/match_matmul.mlir       |  2 +-
 9 files changed, 121 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
index db27f2c6fc49b75..45758a7888360e9 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
 
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectRegistry.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/TypeID.h"
@@ -315,7 +316,42 @@ class BuildOnly : public DerivedTy {
   BuildOnly() : DerivedTy(/*buildOnly=*/true) {}
 };
 
+/// Transform dialect extra data owning transform library modules that were
+/// preloaded into the context (e.g., via registerTransformLibraryPreloader),
+/// so that passes do not need to parse them at initialization time.
+class TransformLibraries : public TransformDialectData<TransformLibraries> {
+public:
+  explicit TransformLibraries(MLIRContext *ctx) : TransformDialectData(ctx) {}
+
+  /// Parses the module in `filename` and appends it to the libraries. Does
+  /// nothing if `filename` is empty. If parsing fails, no library is added
+  /// and isOk() returns false from then on.
+  void parseAndAddLibrary(StringRef filename);
+  /// Appends `library` to the libraries, taking ownership of it.
+  void addLibrary(OwningOpRef<ModuleOp> &&library);
+
+  /// Returns the owned library modules, in the order they were added.
+  ArrayRef<OwningOpRef<ModuleOp>> getLibraries() const { return libraries; }
+
+  /// Returns false if parsing any library has failed.
+  bool isOk() const { return !hadFailures; }
+
+private:
+  /// Owned library modules.
+  SmallVector<OwningOpRef<ModuleOp>> libraries;
+  /// Whether parsing any library failed.
+  bool hadFailures = false;
+};
+
+/// Adds to `registry` a transform dialect extension that, when applied, parses
+/// `filename` and appends the resulting module to the dialect's
+/// TransformLibraries data. No library is parsed if `filename` is empty.
+void registerTransformLibraryPreloader(DialectRegistry &registry,
+                                       StringRef filename);
+
 } // namespace transform
 } // namespace mlir
 
+MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::transform::TransformLibraries)
+
 #endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
index 222a51e8db77eac..32a409fed534e5d 100644
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -90,6 +90,16 @@ class MlirOptMainConfig {
   }
   StringRef getIrdlFile() const { return irdlFileFlag; }
 
+  /// Set the path of a transform dialect library file to preload into the
+  /// transform dialect before processing the input. An empty path (the
+  /// default) preloads nothing.
+  MlirOptMainConfig &setTransformLibrary(StringRef file) {
+    transformLibraryFlag = file;
+    return *this;
+  }
+  /// Return the path of the transform dialect library file to preload.
+  StringRef getTransformLibrary() const { return transformLibraryFlag; }
+
   /// Set the bytecode version to emit.
   MlirOptMainConfig &setEmitBytecodeVersion(int64_t version) {
     emitBytecodeVersion = version;
@@ -222,6 +232,10 @@ class MlirOptMainConfig {
 
   /// Verify that the input IR round-trips perfectly.
   bool verifyRoundtripFlag = false;
+
+  /// Path of the transform dialect library file to preload; empty means no
+  /// library is preloaded.
+  std::string transformLibraryFlag = "";
 };
 
 /// This defines the function type used to setup the pass manager. This can be
diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
index 32c56e903268f74..1f002c890490a30 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
@@ -12,6 +12,8 @@
 #include "mlir/Dialect/Transform/IR/TransformOps.h"
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Support/TypeID.h"
 #include "llvm/ADT/SCCIterator.h"
 
 using namespace mlir;
@@ -172,3 +174,55 @@ LogicalResult transform::TransformDialect::verifyOperationAttribute(
   return emitError(op->getLoc())
          << "unknown attribute: " << attribute.getName();
 }
+
+/// Parses the module in `filename` and, on success, appends it to the owned
+/// libraries. Empty filenames are ignored. On failure, nothing is added and
+/// the failure is recorded so that isOk() returns false.
+void transform::TransformLibraries::parseAndAddLibrary(StringRef filename) {
+  if (filename.empty())
+    return;
+
+  ParserConfig config(getContext());
+  // Parse into a local first so that a null module never enters `libraries`.
+  OwningOpRef<ModuleOp> parsed = parseSourceFile<ModuleOp>(filename, config);
+  if (!parsed) {
+    hadFailures = true;
+    return;
+  }
+  addLibrary(std::move(parsed));
+}
+
+/// Appends `library` to the owned libraries. Users of getLibraries()
+/// dereference the stored modules, so `library` must be non-null.
+void transform::TransformLibraries::addLibrary(
+    OwningOpRef<ModuleOp> &&library) {
+  assert(library && "expected a non-null transform library module");
+  libraries.push_back(std::move(library));
+}
+
+namespace {
+/// Transform dialect extension that preloads the library at `filename` into
+/// the TransformLibraries data of the transform dialect it is applied to.
+class LibraryPreloaderExtension
+    : public transform::TransformDialectExtension<LibraryPreloaderExtension> {
+public:
+  explicit LibraryPreloaderExtension(StringRef filename)
+      : TransformDialectExtension(/*buildOnly=*/false) {
+    // Copy the filename: the initializer outlives the caller's StringRef.
+    std::string ownedFilename = filename.str();
+    addDialectDataInitializer<transform::TransformLibraries>(
+        [ownedFilename](transform::TransformLibraries &libraries) {
+          libraries.parseAndAddLibrary(ownedFilename);
+        });
+  }
+};
+} // namespace
+
+MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::transform::TransformLibraries)
+
+/// Registers in `registry` an extension that preloads the transform library
+/// at `filename` when applied to the transform dialect.
+void mlir::transform::registerTransformLibraryPreloader(
+    DialectRegistry &registry, StringRef filename) {
+  registry.addExtension(std::make_unique<LibraryPreloaderExtension>(filename));
+}
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index b74e14d377c036f..603a4e86e5b6e01 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -21,6 +21,7 @@
 #include "mlir/Parser/Parser.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/FileUtilities.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Debug.h"
@@ -465,19 +466,50 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
     std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
     function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
         moduleBuilder) {
+
+  // Transform libraries are now preloaded into the transform dialect instead
+  // of being parsed here. Reject the legacy pass option rather than silently
+  // ignoring it and running the interpreter without the library.
+  if (!transformLibraryFileName.empty()) {
+    return emitError(UnknownLoc::get(context),
+                     "transform library file name pass option is no longer "
+                     "supported, preload the library (e.g., with "
+                     "-transform-library) instead");
+  }
+
+  // getLoadedDialect returns null if the dialect has not been loaded; report
+  // that instead of dereferencing it.
+  auto *transformDialect =
+      context->getLoadedDialect<transform::TransformDialect>();
+  if (!transformDialect) {
+    return emitError(UnknownLoc::get(context),
+                     "transform dialect must be loaded to use preloaded "
+                     "transform libraries");
+  }
+
+  // NOTE(review): this presumably requires the TransformLibraries data to have
+  // been created, e.g. by registerTransformLibraryPreloader; confirm that
+  // tools not based on MlirOptMain register it.
+  const auto &libraries =
+      transformDialect->getExtraData<transform::TransformLibraries>();
+  if (!libraries.isOk()) {
+    return emitError(UnknownLoc::get(context),
+                     "transform interpreter disabled due to earlier errors in "
+                     "preloading transform libraries");
+  }
+  auto libraryRange = libraries.getLibraries();
+  // Only a single library module can currently be merged into the schedule.
+  if (llvm::range_size(libraryRange) > 1) {
+    return emitError(UnknownLoc::get(context),
+                     "multiple library files not yet supported");
+  }
+
   OwningOpRef<ModuleOp> parsed;
   if (failed(parseTransformModuleFromFile(context, transformFileName, parsed)))
     return failure();
   if (parsed && failed(mlir::verify(*parsed)))
     return failure();
 
-  OwningOpRef<ModuleOp> parsedLibrary;
-  if (failed(parseTransformModuleFromFile(context, transformLibraryFileName,
-                                          parsedLibrary)))
-    return failure();
-  if (parsedLibrary && failed(mlir::verify(*parsedLibrary)))
-    return failure();
-
   if (parsed) {
     module = std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsed));
   } else if (moduleBuilder) {
@@ -495,16 +527,19 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
     }
   }
 
-  if (!parsedLibrary || !*parsedLibrary)
+  // Nothing to merge or expose when no library was preloaded.
+  if (libraryRange.empty())
     return success();
 
   if (module && *module) {
     if (failed(defineDeclaredSymbols(*module->get().getBody(),
-                                     parsedLibrary.get())))
+                                     libraryRange.begin()->get())))
       return failure();
   } else {
-    libraryModule =
-        std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsedLibrary));
+    // TODO: we shouldn't clone or transfer ownership here, the dialect
+    // extension should be able to keep owning this instead.
+    libraryModule = std::make_shared<OwningOpRef<ModuleOp>>(
+        libraryRange.begin()->get().clone());
   }
   return success();
 }
diff --git a/mlir/lib/Tools/mlir-opt/CMakeLists.txt b/mlir/lib/Tools/mlir-opt/CMakeLists.txt
index f24d4c60174eeca..11ff1da1ec1b7e1 100644
--- a/mlir/lib/Tools/mlir-opt/CMakeLists.txt
+++ b/mlir/lib/Tools/mlir-opt/CMakeLists.txt
@@ -13,4 +13,5 @@ add_mlir_library(MLIROptLib
   MLIRPluginsLib
   MLIRSupport
   MLIRIRDL
+  MLIRTransformDialect
   )
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 644113058bdc1cc..c195f59c7527b11 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Debug/Observers/ActionLogging.h"
 #include "mlir/Dialect/IRDL/IR/IRDL.h"
 #include "mlir/Dialect/IRDL/IRDLLoading.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -33,6 +34,7 @@
 #include "mlir/Support/FileUtilities.h"
 #include "mlir/Support/Timing.h"
 #include "mlir/Support/ToolUtilities.h"
+#include "mlir/Support/TypeID.h"
 #include "mlir/Tools/ParseUtilities.h"
 #include "mlir/Tools/Plugins/DialectPlugin.h"
 #include "mlir/Tools/Plugins/PassPlugin.h"
@@ -144,6 +146,14 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
         cl::desc("Round-trip the IR after parsing and ensure it succeeds"),
         cl::location(verifyRoundtripFlag), cl::init(false));
 
+    // Path of a transform dialect library to preload; stored in
+    // transformLibraryFlag and read back via getTransformLibrary().
+    static cl::opt<std::string, /*ExternalStorage=*/true> transformLibrary(
+        "transform-library",
+        cl::desc("Library of transform dialect symbols to preload"),
+        cl::location(transformLibraryFlag), cl::init(""),
+        cl::value_desc("filename"));
+
     static cl::list<std::string> passPlugins(
         "load-pass-plugin", cl::desc("Load passes from plugin library"));
     /// Set the callback to load a pass plugin.
@@ -424,6 +434,22 @@ static LogicalResult processBuffer(raw_ostream &os,
   tracing::InstallDebugHandler installDebugHandler(context,
                                                    config.getDebugConfig());
 
+  // Preload the transform library if requested. This should happen after
+  // the debug handler configuration.
+  // The preloader is registered and the transform dialect loaded even when no
+  // library was requested (an empty filename parses nothing), presumably so
+  // that the TransformLibraries data always exists for the interpreter.
+  // NOTE(review): this runs before the verify-diagnostics branch below;
+  // confirm that library parse errors are reported as intended there.
+  {
+    DialectRegistry localRegistry;
+    transform::registerTransformLibraryPreloader(localRegistry,
+                                                 config.getTransformLibrary());
+    context.appendDialectRegistry(localRegistry);
+    // Force extension loading, which is when the library (if any) is parsed.
+    context.loadDialect<transform::TransformDialect>();
+  }
+
   // If we are in verify diagnostics mode then we have a lot of work to do,
   // otherwise just perform the actions without worrying about it.
   if (!config.shouldVerifyDiagnostics()) {
diff --git a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir
index c219cfe08ac4d6a..a86582daa112aa0 100644
--- a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir
@@ -2,7 +2,8 @@
 
 // RUN: mlir-opt %s -test-lower-to-llvm -cse | FileCheck %s
 
-// RUN: mlir-opt %s -test-transform-dialect-interpreter="transform-library-file-name=%p/lower-to-llvm-transform-symbol-def.mlir debug-payload-root-tag=payload" \
+// RUN: mlir-opt %s -test-transform-dialect-interpreter="debug-payload-root-tag=payload" \
+// RUN:   -transform-library=%p/lower-to-llvm-transform-symbol-def.mlir \
 // RUN:   -test-transform-dialect-erase-schedule -cse \
 // RUN: | FileCheck %s
 
diff --git a/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir b/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir
index 73bc243ad76060d..6c81c61110f35e4 100644
--- a/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir
+++ b/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-file-name=%p/match_matmul_common.mlir' --verify-diagnostics
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --transform-library=%p/match_matmul_common.mlir --verify-diagnostics
 
 module attributes { transform.with_named_sequence } {
   transform.named_sequence @_match_matmul_like(
diff --git a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
index f164a3d1bd99dd0..5f8f93954e36cce 100644
--- a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
+++ b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-file-name=%p/match_matmul_common.mlir' --verify-diagnostics
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --transform-library=%p/match_matmul_common.mlir --verify-diagnostics
 
 module attributes { transform.with_named_sequence } {
   transform.named_sequence @_match_matmul_like(



More information about the Mlir-commits mailing list