[Mlir-commits] [mlir] [mlir] WIP: preload transform libraries before the pass (PR #68000)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Mon Oct 2 08:41:04 PDT 2023
https://github.com/ftynse created https://github.com/llvm/llvm-project/pull/68000
Loading in a pass happens at pass initialization time, and may be repeated if the pass is called more than once. Consider a different mechanism based on extra data owned by the transform dialect. This requires changes to mlir-opt (and equivalent downstream setups), but there is precedent with IRDL.
>From 734ff68a1fcb6c877ea19f94aa4ce1c7b420cbc8 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Mon, 2 Oct 2023 15:38:03 +0000
Subject: [PATCH] [mlir] WIP: preload transform libraries before the pass
Loading in a pass happens at pass initialization time, and may be
repeated if the pass is called more than once. Consider a different
mechanism based on extra data owned by the transform dialect. This
requires changes to mlir-opt (and equivalent downstream setups), but
there is precedent with IRDL.
---
.../Dialect/Transform/IR/TransformDialect.h | 24 +++++++++++
.../include/mlir/Tools/mlir-opt/MlirOptMain.h | 10 +++++
.../Dialect/Transform/IR/TransformDialect.cpp | 41 +++++++++++++++++++
.../TransformInterpreterPassBase.cpp | 33 ++++++++++-----
mlir/lib/Tools/mlir-opt/CMakeLists.txt | 1 +
mlir/lib/Tools/mlir-opt/MlirOptMain.cpp | 19 +++++++++
mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir | 3 +-
.../Dialect/Transform/match_batch_matmul.mlir | 2 +-
.../Dialect/Transform/match_matmul.mlir | 2 +-
9 files changed, 121 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
index db27f2c6fc49b75..45758a7888360e9 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
@@ -10,6 +10,7 @@
#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/TypeID.h"
@@ -315,7 +316,30 @@ class BuildOnly : public DerivedTy {
BuildOnly() : DerivedTy(/*buildOnly=*/true) {}
};
+class TransformLibraries : public TransformDialectData<TransformLibraries> {
+public:
+ explicit TransformLibraries(MLIRContext *ctx) : TransformDialectData(ctx) {}
+
+ void parseAndAddLibrary(StringRef filename);
+ void addLibrary(OwningOpRef<ModuleOp> &&library);
+
+ auto getLibraries() const {
+ return llvm::make_range(libraries.begin(), libraries.end());
+ }
+
+ bool isOk() const { return !hadFailures; }
+
+private:
+ SmallVector<OwningOpRef<ModuleOp>> libraries;
+ bool hadFailures = false;
+};
+
+void registerTransformLibraryPreloader(DialectRegistry &registry,
+ StringRef filename);
+
} // namespace transform
} // namespace mlir
+MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::transform::TransformLibraries)
+
#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
index 222a51e8db77eac..32a409fed534e5d 100644
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -90,6 +90,13 @@ class MlirOptMainConfig {
}
StringRef getIrdlFile() const { return irdlFileFlag; }
+ /// Set the transform dialect library to preload.
+ MlirOptMainConfig &setTransformLibrary(StringRef file) {
+ transformLibraryFlag = file;
+ return *this;
+ }
+ StringRef getTransformLibrary() const { return transformLibraryFlag; }
+
/// Set the bytecode version to emit.
MlirOptMainConfig &setEmitBytecodeVersion(int64_t version) {
emitBytecodeVersion = version;
@@ -222,6 +229,9 @@ class MlirOptMainConfig {
/// Verify that the input IR round-trips perfectly.
bool verifyRoundtripFlag = false;
+
+ /// Transform dialect library to preload.
+ std::string transformLibraryFlag = "";
};
/// This defines the function type used to setup the pass manager. This can be
diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
index 32c56e903268f74..1f002c890490a30 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
@@ -12,6 +12,8 @@
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/IR/DialectImplementation.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Support/TypeID.h"
#include "llvm/ADT/SCCIterator.h"
using namespace mlir;
@@ -172,3 +174,42 @@ LogicalResult transform::TransformDialect::verifyOperationAttribute(
return emitError(op->getLoc())
<< "unknown attribute: " << attribute.getName();
}
+
+void transform::TransformLibraries::parseAndAddLibrary(StringRef filename) {
+ if (filename.empty())
+ return;
+
+ ParserConfig config(getContext());
+ addLibrary(parseSourceFile<ModuleOp>(filename, config));
+ if (!libraries.back()) {
+ hadFailures = true;
+ libraries.pop_back();
+ }
+}
+
+void transform::TransformLibraries::addLibrary(
+ OwningOpRef<ModuleOp> &&library) {
+ libraries.push_back(std::move(library));
+}
+
+namespace {
+class LibraryPreloaderExtension
+ : public transform::TransformDialectExtension<LibraryPreloaderExtension> {
+public:
+ explicit LibraryPreloaderExtension(StringRef filename)
+ : TransformDialectExtension(/*buildOnly=*/false) {
+ std::string ownedFilename = filename.str();
+ addDialectDataInitializer<transform::TransformLibraries>(
+ [ownedFilename](transform::TransformLibraries &libraries) {
+ libraries.parseAndAddLibrary(ownedFilename);
+ });
+ }
+};
+} // namespace
+
+MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::transform::TransformLibraries)
+
+void mlir::transform::registerTransformLibraryPreloader(
+ DialectRegistry &registry, StringRef filename) {
+ registry.addExtension(std::make_unique<LibraryPreloaderExtension>(filename));
+}
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index b74e14d377c036f..603a4e86e5b6e01 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -21,6 +21,7 @@
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/FileUtilities.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"
@@ -465,19 +466,27 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
moduleBuilder) {
+
+ const auto &libraries =
+ context->getLoadedDialect<transform::TransformDialect>()
+ ->getExtraData<transform::TransformLibraries>();
+ if (!libraries.isOk()) {
+ return emitError(UnknownLoc::get(context),
+ "transform interpreter disabled due to earlier errors in "
+ "preloading transform libraries");
+ }
+ auto libraryRange = libraries.getLibraries();
+ if (llvm::range_size(libraryRange) > 1) {
+ return emitError(UnknownLoc::get(context),
+ "multiple library files not yet supported");
+ }
+
OwningOpRef<ModuleOp> parsed;
if (failed(parseTransformModuleFromFile(context, transformFileName, parsed)))
return failure();
if (parsed && failed(mlir::verify(*parsed)))
return failure();
- OwningOpRef<ModuleOp> parsedLibrary;
- if (failed(parseTransformModuleFromFile(context, transformLibraryFileName,
- parsedLibrary)))
- return failure();
- if (parsedLibrary && failed(mlir::verify(*parsedLibrary)))
- return failure();
-
if (parsed) {
module = std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsed));
} else if (moduleBuilder) {
@@ -495,16 +504,18 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
}
}
- if (!parsedLibrary || !*parsedLibrary)
+ if (libraryRange.empty())
return success();
if (module && *module) {
if (failed(defineDeclaredSymbols(*module->get().getBody(),
- parsedLibrary.get())))
+ libraryRange.begin()->get())))
return failure();
} else {
- libraryModule =
- std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsedLibrary));
+ // TODO: we shouldn't clone or transfer ownership here, the dialect
+ // extension should be able to keep owning this instead.
+ libraryModule = std::make_shared<OwningOpRef<ModuleOp>>(
+ libraryRange.begin()->get().clone());
}
return success();
}
diff --git a/mlir/lib/Tools/mlir-opt/CMakeLists.txt b/mlir/lib/Tools/mlir-opt/CMakeLists.txt
index f24d4c60174eeca..11ff1da1ec1b7e1 100644
--- a/mlir/lib/Tools/mlir-opt/CMakeLists.txt
+++ b/mlir/lib/Tools/mlir-opt/CMakeLists.txt
@@ -13,4 +13,5 @@ add_mlir_library(MLIROptLib
MLIRPluginsLib
MLIRSupport
MLIRIRDL
+ MLIRTransformDialect
)
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 644113058bdc1cc..c195f59c7527b11 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -20,6 +20,7 @@
#include "mlir/Debug/Observers/ActionLogging.h"
#include "mlir/Dialect/IRDL/IR/IRDL.h"
#include "mlir/Dialect/IRDL/IRDLLoading.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
@@ -33,6 +34,7 @@
#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/Timing.h"
#include "mlir/Support/ToolUtilities.h"
+#include "mlir/Support/TypeID.h"
#include "mlir/Tools/ParseUtilities.h"
#include "mlir/Tools/Plugins/DialectPlugin.h"
#include "mlir/Tools/Plugins/PassPlugin.h"
@@ -144,6 +146,12 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
cl::desc("Round-trip the IR after parsing and ensure it succeeds"),
cl::location(verifyRoundtripFlag), cl::init(false));
+ static cl::opt<std::string, /*ExternalStorage=*/true> transformLibrary(
+ "transform-library",
+ cl::desc("Library of transform dialect symbols to preload"),
+ cl::location(transformLibraryFlag), cl::init(""),
+ cl::value_desc("filename"));
+
static cl::list<std::string> passPlugins(
"load-pass-plugin", cl::desc("Load passes from plugin library"));
/// Set the callback to load a pass plugin.
@@ -424,6 +432,17 @@ static LogicalResult processBuffer(raw_ostream &os,
tracing::InstallDebugHandler installDebugHandler(context,
config.getDebugConfig());
+ // Preload the transform library if requested. This should happen after
+ // the debug handler configuration.
+ {
+ DialectRegistry localRegistry;
+ transform::registerTransformLibraryPreloader(localRegistry,
+ config.getTransformLibrary());
+ context.appendDialectRegistry(localRegistry);
+ // Force extension loading.
+ context.loadDialect<transform::TransformDialect>();
+ }
+
// If we are in verify diagnostics mode then we have a lot of work to do,
// otherwise just perform the actions without worrying about it.
if (!config.shouldVerifyDiagnostics()) {
diff --git a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir
index c219cfe08ac4d6a..a86582daa112aa0 100644
--- a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir
@@ -2,7 +2,8 @@
// RUN: mlir-opt %s -test-lower-to-llvm -cse | FileCheck %s
-// RUN: mlir-opt %s -test-transform-dialect-interpreter="transform-library-file-name=%p/lower-to-llvm-transform-symbol-def.mlir debug-payload-root-tag=payload" \
+// RUN: mlir-opt %s -test-transform-dialect-interpreter="debug-payload-root-tag=payload" \
+// RUN: -transform-library=%p/lower-to-llvm-transform-symbol-def.mlir \
// RUN: -test-transform-dialect-erase-schedule -cse \
// RUN: | FileCheck %s
diff --git a/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir b/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir
index 73bc243ad76060d..6c81c61110f35e4 100644
--- a/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir
+++ b/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-file-name=%p/match_matmul_common.mlir' --verify-diagnostics
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --transform-library=%p/match_matmul_common.mlir --verify-diagnostics
module attributes { transform.with_named_sequence } {
transform.named_sequence @_match_matmul_like(
diff --git a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
index f164a3d1bd99dd0..5f8f93954e36cce 100644
--- a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
+++ b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-file-name=%p/match_matmul_common.mlir' --verify-diagnostics
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --transform-library=%p/match_matmul_common.mlir --verify-diagnostics
module attributes { transform.with_named_sequence } {
transform.named_sequence @_match_matmul_like(
More information about the Mlir-commits
mailing list