[Mlir-commits] [mlir] [mlir][Transform] Create a transform interpreter and a preloader pass (PR #68661)
llvmlistbot at llvm.org
Mon Oct 9 22:00:12 PDT 2023
llvmbot wrote:
@llvm/pr-subscribers-mlir-llvm
Author: Nicolas Vasilache (nicolasvasilache)
<details>
<summary>Changes</summary>
This revision provides the ability to use an arbitrary named sequence op as
the entry point to a transform dialect strategy.
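The sketch below is a rough illustration of what using "an arbitrary named sequence as the entry point" means. It models a transform module as a symbol table and looks up the entry point by name, falling back to `__transform_main`. The `NamedSequence`, `TransformModule` and `findEntryPoint` names are hypothetical stand-ins, not the MLIR API (the patch itself goes through `transform::detail::findTransformEntryPoint`).

```cpp
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical model (not the MLIR API): a named sequence is a symbol name plus
// an optional body; std::nullopt stands for an external declaration.
struct NamedSequence {
  std::string name;
  std::optional<std::vector<std::string>> body;
};

// Symbol table of a transform module, keyed by symbol name.
using TransformModule = std::map<std::string, NamedSequence>;

// Returns the named sequence to start interpretation from, or nullptr if no
// sequence with that name exists. Any named sequence may serve as the entry
// point; `__transform_main` is only the default name.
const NamedSequence *
findEntryPoint(const TransformModule &module,
               const std::string &entryPoint = "__transform_main") {
  auto it = module.find(entryPoint);
  return it == module.end() ? nullptr : &it->second;
}
```

Selecting a non-default entry point is then just a matter of passing another name, which is what the `entry-point` option of the new interpreter pass controls.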
It is also a step towards better transform dialect usage in pass pipelines
that need to preload a transform library rather than parse it on the fly.
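Here is a minimal C++ sketch of the preload-then-interpret split, under simplifying assumptions. A `std::map` stands in for the parsed library. A `Context` struct stands in for the MLIR context and for the Transform dialect slot that `registerLibraryModule` fills. All names are hypothetical. The library is built once, and every interpreter instance holds a shared, read-only handle to it instead of reparsing it.

```cpp
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical stand-in for a transform library: symbol name -> body text.
using Library = std::map<std::string, std::string>;

// Hypothetical stand-in for per-context dialect state holding the library.
struct Context {
  std::shared_ptr<const Library> preloadedLibrary;
};

// "Preload" step: runs once, before any interpreter instance, so the library
// is parsed and linked a single time rather than on every interpreter run.
void preloadLibrary(Context &ctx, Library merged) {
  // Sketch-specific choice: reject a second preload into the same context.
  if (ctx.preloadedLibrary)
    throw std::logic_error("only one preload step per context is supported");
  ctx.preloadedLibrary = std::make_shared<const Library>(std::move(merged));
}

// "Interpreter" instance: grabs a shared, read-only handle at initialization so
// that instances running in parallel neither reparse nor copy the library.
struct Interpreter {
  std::shared_ptr<const Library> library;
  explicit Interpreter(const Context &ctx) : library(ctx.preloadedLibrary) {}
  bool hasEntryPoint(const std::string &name) const {
    return library && library->count(name) != 0;
  }
};
```

The patch's `InterpreterPass::initialize` clones the preloaded module once and shares that clone through a `std::shared_ptr`. The sketch shares a `const` pointer directly. Rejecting a second preload is only this sketch's way of honoring the "single such pass per context" warning, not necessarily what the dialect does.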
The interpreter itself is significantly simpler than its testing counterpart,
since it does without payload/debug root tags and multiple shared modules.
In the process, the NamedSequenceOp::apply function is adapted so that a
named sequence can serve as an entry point.
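The new `NamedSequenceOp::apply` behavior can be modeled roughly as follows. An external declaration is rejected with the same "unresolved external named sequence" message as the patch. Otherwise the entry block argument is bound to the payload and the body is applied in order, stopping at the first failure as `FailurePropagationMode::Propagate` does. `Step`, `NamedSequence` and `applyNamedSequence` are hypothetical, and the second diagnostic string is made up for the sketch.

```cpp
#include <functional>
#include <optional>
#include <string>
#include <vector>

// Hypothetical model: a transform step receives the payload handle bound to
// the sequence's entry block argument and reports success or failure.
using Step = std::function<bool(std::string &payload)>;

struct NamedSequence {
  std::string name;
  std::optional<std::vector<Step>> body; // nullopt => external declaration
};

// Returns an empty string on success, otherwise a diagnostic message.
std::string applyNamedSequence(const NamedSequence &seq,
                               std::string &payloadRoot) {
  // A declaration that was never linked against a definition cannot run.
  if (!seq.body)
    return "unresolved external named sequence";
  // In this model the entry block argument is bound directly to the payload
  // root; each step is applied in order and the first failure is propagated.
  for (const Step &step : *seq.body)
    if (!step(payloadRoot))
      return "failed to apply sequence '" + seq.name + "'";
  return "";
}
```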
NamedSequenceOp is **not** extended to take the PossibleTopLevelTrait at this
time, because the implementation of the trait is specific to allowing one
top-level dangling op with a region such as SequenceOp or AlternativesOp.
In particular, the verifier of PossibleTopLevelTrait does not allow for an
empty body, which is necessary to declare a NamedSequenceOp that gets linked
in separately before application.
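To make the declaration/definition split concrete, here is a hypothetical sketch of linking a body-less declaration against a library definition. The merge condition mirrors the rule documented for `canMergeInto` in this patch: the merged-away symbol must be a declaration, and the other one must be public or also a declaration. The `Symbol` and `linkInto` names and the map-based representation are assumptions, and the real `mergeSymbolsInto` does more work than this.

```cpp
#include <map>
#include <optional>
#include <string>

// Hypothetical model of a symbol in a transform module. A declaration has no
// body; the PossibleTopLevelTrait verifier would reject such an empty body,
// yet it is exactly what a separately linked named sequence needs.
struct Symbol {
  std::optional<std::string> body; // nullopt => external declaration
  bool isPublic = true;
  bool isExternal() const { return !body.has_value(); }
};

// `a` can be merged into `b` only if `a` is a declaration and `b` is either
// public (visible to `a`) or a declaration as well.
bool canMergeInto(const Symbol &a, const Symbol &b) {
  return a.isExternal() && (b.isPublic || b.isExternal());
}

// Links `library` into `target`, resolving declarations with definitions.
// Returns false when two symbols with the same name cannot be reconciled.
bool linkInto(std::map<std::string, Symbol> &target,
              const std::map<std::string, Symbol> &library) {
  for (const auto &[name, libSym] : library) {
    auto it = target.find(name);
    if (it == target.end()) {
      target.emplace(name, libSym); // new symbol, just copy it over
      continue;
    }
    if (canMergeInto(it->second, libSym))
      it->second = libSym; // declaration resolved by the library definition
    else if (!canMergeInto(libSym, it->second))
      return false; // two definitions, or a non-visible target
    // Otherwise the library only re-declares what `target` already defines.
  }
  return true;
}
```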
In the future, we should dispense with the PossibleTopLevelTrait altogether
and always enter the interpreter with a NamedSequenceOp.
Lastly, the relevant transform dialect (TD) linking utilities are moved to
TransformInterpreterUtils and reused from there.
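One of the moved utilities, `expandPathsToMLIRFiles`, documents a simple rule: regular files pass through, directories contribute their regular `.mlir` entries, and anything else is an error. Below is a standalone sketch of that rule using C++17 `std::filesystem` instead of `llvm::sys::fs`. The function name is reused only for readability; the signature, which returns a `std::optional` plus an error string, is an assumption.

```cpp
#include <filesystem>
#include <optional>
#include <string>
#include <system_error>
#include <vector>

namespace fs = std::filesystem;

// Sketch of the documented expansion rule. Returns std::nullopt on failure
// and stores a diagnostic in `error`.
std::optional<std::vector<std::string>>
expandPathsToMLIRFiles(const std::vector<std::string> &paths,
                       std::string &error) {
  std::vector<std::string> fileNames;
  for (const std::string &path : paths) {
    std::error_code ec;
    // A regular file is kept as given, whatever its extension.
    if (fs::is_regular_file(path, ec)) {
      fileNames.push_back(path);
      continue;
    }
    if (!fs::is_directory(path, ec)) {
      error = "'" + path + "' is neither a file nor a directory";
      return std::nullopt;
    }
    // Non-recursive scan of the directory; entry order is unspecified.
    for (fs::directory_iterator it(path, ec), end; !ec && it != end;
         it.increment(ec)) {
      std::error_code typeEc;
      // Skip non-regular entries and files without the `.mlir` extension.
      if (!it->is_regular_file(typeEc) ||
          it->path().extension().string() != ".mlir")
        continue;
      fileNames.push_back(it->path().string());
    }
    if (ec) {
      error = "error while opening files in '" + path + "': " + ec.message();
      return std::nullopt;
    }
  }
  return fileNames;
}
```

Like the original, the directory scan is not recursive, so a subdirectory is skipped even if its name ends in `.mlir`.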
---
Patch is 24.89 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/68661.diff
12 Files Affected:
- (modified) mlir/include/mlir/Dialect/Transform/Transforms/Passes.td (+31)
- (modified) mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h (+18)
- (modified) mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp (+7-3)
- (modified) mlir/lib/Dialect/Transform/IR/TransformOps.cpp (+14-2)
- (modified) mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt (+2)
- (added) mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp (+55)
- (added) mlir/lib/Dialect/Transform/Transforms/PreloadLibraryPass.cpp (+41)
- (modified) mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp (-53)
- (modified) mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp (+114-11)
- (modified) mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir (+9-10)
- (renamed) mlir/test/Dialect/Transform/Library/lower-to-llvm.mlir (+14-5)
- (modified) utils/bazel/llvm-project-overlay/mlir/test/Dialect/BUILD.bazel (+1-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/Passes.td b/mlir/include/mlir/Dialect/Transform/Transforms/Passes.td
index 2400066c8ad8c8b..45a26496335850a 100644
--- a/mlir/include/mlir/Dialect/Transform/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/Passes.td
@@ -42,4 +42,35 @@ def InferEffectsPass : Pass<"transform-infer-effects"> {
}];
}
+def PreloadLibraryPass : Pass<"transform-preload-library"> {
+ let summary = "preload transform dialect library";
+ let description = [{
+ This pass preloads a transform library and makes it available to subsequent
+ transform interpreter passes. The library is preloaded into the Transform
+ dialect, which provides very limited functionality that does not scale.
+
+ Warning: Only a single such pass should exist for a given MLIR context.
+ This is a temporary solution until a resource-based solution is available.
+ TODO: use a resource blob.
+ }];
+ let options = [
+ ListOption<"transformLibraryPaths", "transform-library-paths", "std::string",
+ "Optional paths to files with modules that should be merged into the "
+ "transform module to provide the definitions of external named sequences.">
+ ];
+}
+
+def InterpreterPass : Pass<"transform-interpreter"> {
+ let summary = "transform dialect interpreter";
+ let description = [{
+ This pass runs the transform dialect interpreter and applies the named
+ sequence transformation specified by the provided name (defaults to
+ `__transform_main`).
+ }];
+ let options = [
+ Option<"entryPoint", "entry-point", "std::string",
+ /*default=*/[{"__transform_main"}],
+ "Entry point of the pass pipeline.">,
+ ];
+}
#endif // MLIR_DIALECT_TRANSFORM_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
index 36c80e6fd61d3c1..3fc02267f26e9da 100644
--- a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
@@ -26,6 +26,16 @@ class Region;
namespace transform {
namespace detail {
+
+/// Expands the given list of `paths` to a list of `.mlir` files.
+///
+/// Each entry in `paths` may either be a regular file, in which case it ends up
+/// in the result list, or a directory, in which case all (regular) `.mlir`
+/// files in that directory are added. Any other file types lead to a failure.
+LogicalResult expandPathsToMLIRFiles(ArrayRef<std::string> &paths,
+ MLIRContext *context,
+ SmallVectorImpl<std::string> &fileNames);
+
/// Utility to parse and verify the content of a `transformFileName` MLIR file
/// containing a transform dialect specification.
LogicalResult
@@ -33,6 +43,14 @@ parseTransformModuleFromFile(MLIRContext *context,
llvm::StringRef transformFileName,
OwningOpRef<ModuleOp> &transformModule);
+/// Utility to parse, verify, aggregate and link the content of all mlir files
+/// nested under `transformLibraryPaths` and containing transform dialect
+/// specifications.
+LogicalResult
+assembleTransformLibraryFromPaths(MLIRContext *context,
+ ArrayRef<std::string> transformLibraryPaths,
+ OwningOpRef<ModuleOp> &transformModule);
+
/// Utility to load a transform interpreter `module` from a module that has
/// already been preloaded in the context.
/// This mode is useful in cases where explicit parsing of a transform library
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 4f88b8522e54c80..91d8302150808d0 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -60,9 +60,13 @@ ArrayRef<Operation *>
transform::TransformState::getPayloadOpsView(Value value) const {
const TransformOpMapping &operationMapping = getMapping(value).direct;
auto iter = operationMapping.find(value);
- assert(
- iter != operationMapping.end() &&
- "cannot find mapping for payload handle (param/value handle provided?)");
+
+ if (iter == operationMapping.end()) {
+ value.dump();
+ assert(false &&
+ "cannot find mapping for payload handle (param/value handle "
+ "provided?)");
+ }
return iter->getSecond();
}
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 0e20b379cc2a3e7..dcbaaf04c49490d 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1761,8 +1761,20 @@ DiagnosedSilenceableFailure
transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
- // Nothing to do here.
- return DiagnosedSilenceableFailure::success();
+ if (isExternal())
+ return emitDefiniteFailure() << "unresolved external named sequence";
+
+ // Map the entry block argument to the list of operations.
+ // Note: this is the same implementation as PossibleTopLevelTransformOp but
+ // without attaching the interface / trait since that is tailored to a
+ // dangling top-level op that does not get "called".
+ auto scope = state.make_region_scope(getBody());
+ if (failed(detail::mapPossibleTopLevelTransformOpBlockArguments(
+ state, this->getOperation(), getBody())))
+ return DiagnosedSilenceableFailure::definiteFailure();
+
+ return applySequenceBlock(getBody().front(),
+ FailurePropagationMode::Propagate, state, results);
}
void transform::NamedSequenceOp::getEffects(
diff --git a/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
index 8774a8b86fb0d91..f0f57874f5e7032 100644
--- a/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
@@ -1,6 +1,8 @@
add_mlir_dialect_library(MLIRTransformDialectTransforms
CheckUses.cpp
InferEffects.cpp
+ InterpreterPass.cpp
+ PreloadLibraryPass.cpp
TransformInterpreterPassBase.cpp
TransformInterpreterUtils.cpp
diff --git a/mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp b/mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp
new file mode 100644
index 000000000000000..4a6921b8611bf46
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp
@@ -0,0 +1,55 @@
+//===- InterpreterPass.cpp - Transform dialect interpreter pass -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Transforms/Passes.h"
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
+
+using namespace mlir;
+
+namespace mlir {
+namespace transform {
+#define GEN_PASS_DEF_INTERPRETERPASS
+#include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
+} // namespace transform
+} // namespace mlir
+
+namespace {
+class InterpreterPass
+ : public transform::impl::InterpreterPassBase<InterpreterPass> {
+public:
+ using Base::Base;
+
+ LogicalResult initialize(MLIRContext *context) override {
+ // TODO: use a resource blob.
+ ModuleOp transformModule =
+ transform::detail::getPreloadedTransformModule(context);
+ if (transformModule) {
+ sharedTransformModule =
+ std::make_shared<OwningOpRef<ModuleOp>>(transformModule.clone());
+ }
+ return success();
+ }
+
+ void runOnOperation() override {
+ if (failed(transform::applyTransformNamedSequence(
+ getOperation(), sharedTransformModule->get(),
+ options.enableExpensiveChecks(true), entryPoint)))
+ return signalPassFailure();
+ }
+
+private:
+ /// Transform interpreter options.
+ transform::TransformOptions options;
+
+ /// The separate transform module to be used for transformations, shared
+ /// across multiple instances of the pass if it is applied in parallel to
+ /// avoid potentially expensive cloning. MUST NOT be modified after the pass
+ /// has been initialized.
+ std::shared_ptr<OwningOpRef<ModuleOp>> sharedTransformModule = nullptr;
+};
+} // namespace
diff --git a/mlir/lib/Dialect/Transform/Transforms/PreloadLibraryPass.cpp b/mlir/lib/Dialect/Transform/Transforms/PreloadLibraryPass.cpp
new file mode 100644
index 000000000000000..795e31637617fe4
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/Transforms/PreloadLibraryPass.cpp
@@ -0,0 +1,41 @@
+//===- PreloadLibraryPass.cpp - Pass to preload a transform library -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Transforms/Passes.h"
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
+
+using namespace mlir;
+
+namespace mlir {
+namespace transform {
+#define GEN_PASS_DEF_PRELOADLIBRARYPASS
+#include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
+} // namespace transform
+} // namespace mlir
+
+namespace {
+class PreloadLibraryPass
+ : public transform::impl::PreloadLibraryPassBase<PreloadLibraryPass> {
+public:
+ using Base::Base;
+
+ LogicalResult initialize(MLIRContext *context) override {
+ OwningOpRef<ModuleOp> mergedParsedLibraries;
+ if (failed(transform::detail::assembleTransformLibraryFromPaths(
+ context, transformLibraryPaths, mergedParsedLibraries)))
+ return failure();
+ // TODO: use a resource blob.
+ auto *dialect = context->getOrLoadDialect<transform::TransformDialect>();
+ dialect->registerLibraryModule(std::move(mergedParsedLibraries));
+ return success();
+ }
+
+ void runOnOperation() override {}
+};
+} // namespace
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index 5f35b6789dc94fe..538c81fe39fddb2 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -357,59 +357,6 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
extraMappings, options);
}
-/// Expands the given list of `paths` to a list of `.mlir` files.
-///
-/// Each entry in `paths` may either be a regular file, in which case it ends up
-/// in the result list, or a directory, in which case all (regular) `.mlir`
-/// files in that directory are added. Any other file types lead to a failure.
-static LogicalResult
-expandPathsToMLIRFiles(ArrayRef<std::string> &paths, MLIRContext *const context,
- SmallVectorImpl<std::string> &fileNames) {
- for (const std::string &path : paths) {
- auto loc = FileLineColLoc::get(context, path, 0, 0);
-
- if (llvm::sys::fs::is_regular_file(path)) {
- LLVM_DEBUG(DBGS() << "Adding '" << path << "' to list of files\n");
- fileNames.push_back(path);
- continue;
- }
-
- if (!llvm::sys::fs::is_directory(path)) {
- return emitError(loc)
- << "'" << path << "' is neither a file nor a directory";
- }
-
- LLVM_DEBUG(DBGS() << "Looking for files in '" << path << "':\n");
-
- std::error_code ec;
- for (llvm::sys::fs::directory_iterator it(path, ec), itEnd;
- it != itEnd && !ec; it.increment(ec)) {
- const std::string &fileName = it->path();
-
- if (it->type() != llvm::sys::fs::file_type::regular_file) {
- LLVM_DEBUG(DBGS() << " Skipping non-regular file '" << fileName
- << "'\n");
- continue;
- }
-
- if (!StringRef(fileName).endswith(".mlir")) {
- LLVM_DEBUG(DBGS() << " Skipping '" << fileName
- << "' because it does not end with '.mlir'\n");
- continue;
- }
-
- LLVM_DEBUG(DBGS() << " Adding '" << fileName << "' to list of files\n");
- fileNames.push_back(fileName);
- }
-
- if (ec)
- return emitError(loc) << "error while opening files in '" << path
- << "': " << ec.message();
- }
-
- return success();
-}
-
LogicalResult transform::detail::interpreterBaseInitializeImpl(
MLIRContext *context, StringRef transformFileName,
ArrayRef<std::string> transformLibraryPaths,
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
index 1a6ebdd16232e8a..eb6f83dc8330563 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
@@ -23,6 +23,7 @@
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/FileSystem.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
@@ -31,6 +32,59 @@ using namespace mlir;
#define DEBUG_TYPE "transform-dialect-interpreter-utils"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+/// Expands the given list of `paths` to a list of `.mlir` files.
+///
+/// Each entry in `paths` may either be a regular file, in which case it ends up
+/// in the result list, or a directory, in which case all (regular) `.mlir`
+/// files in that directory are added. Any other file types lead to a failure.
+LogicalResult transform::detail::expandPathsToMLIRFiles(
+ ArrayRef<std::string> &paths, MLIRContext *context,
+ SmallVectorImpl<std::string> &fileNames) {
+ for (const std::string &path : paths) {
+ auto loc = FileLineColLoc::get(context, path, 0, 0);
+
+ if (llvm::sys::fs::is_regular_file(path)) {
+ LLVM_DEBUG(DBGS() << "Adding '" << path << "' to list of files\n");
+ fileNames.push_back(path);
+ continue;
+ }
+
+ if (!llvm::sys::fs::is_directory(path)) {
+ return emitError(loc)
+ << "'" << path << "' is neither a file nor a directory";
+ }
+
+ LLVM_DEBUG(DBGS() << "Looking for files in '" << path << "':\n");
+
+ std::error_code ec;
+ for (llvm::sys::fs::directory_iterator it(path, ec), itEnd;
+ it != itEnd && !ec; it.increment(ec)) {
+ const std::string &fileName = it->path();
+
+ if (it->type() != llvm::sys::fs::file_type::regular_file) {
+ LLVM_DEBUG(DBGS() << " Skipping non-regular file '" << fileName
+ << "'\n");
+ continue;
+ }
+
+ if (!StringRef(fileName).endswith(".mlir")) {
+ LLVM_DEBUG(DBGS() << " Skipping '" << fileName
+ << "' because it does not end with '.mlir'\n");
+ continue;
+ }
+
+ LLVM_DEBUG(DBGS() << " Adding '" << fileName << "' to list of files\n");
+ fileNames.push_back(fileName);
+ }
+
+ if (ec)
+ return emitError(loc) << "error while opening files in '" << path
+ << "': " << ec.message();
+ }
+
+ return success();
+}
+
LogicalResult transform::detail::parseTransformModuleFromFile(
MLIRContext *context, llvm::StringRef transformFileName,
OwningOpRef<ModuleOp> &transformModule) {
@@ -91,10 +145,51 @@ transform::detail::findTransformEntryPoint(Operation *root, ModuleOp module,
return nullptr;
}
-/// Return whether `func1` can be merged into `func2`. For that to work `func1`
-/// has to be a declaration (aka has to be external) and `func2` either has to
-/// be a declaration as well, or it has to be public (otherwise, it wouldn't
-/// be visible by `func1`).
+LogicalResult transform::detail::assembleTransformLibraryFromPaths(
+ MLIRContext *context, ArrayRef<std::string> transformLibraryPaths,
+ OwningOpRef<ModuleOp> &transformModule) {
+ // Assemble list of library files.
+ SmallVector<std::string> libraryFileNames;
+ if (failed(detail::expandPathsToMLIRFiles(transformLibraryPaths, context,
+ libraryFileNames)))
+ return failure();
+
+ // Parse modules from library files.
+ SmallVector<OwningOpRef<ModuleOp>> parsedLibraries;
+ for (const std::string &libraryFileName : libraryFileNames) {
+ OwningOpRef<ModuleOp> parsedLibrary;
+ auto loc = FileLineColLoc::get(context, libraryFileName, 0, 0);
+ if (failed(transform::detail::parseTransformModuleFromFile(
+ context, libraryFileName, parsedLibrary)))
+ return emitError(loc) << "failed to parse transform library module";
+ parsedLibraries.push_back(std::move(parsedLibrary));
+ }
+
+ // Merge parsed libraries into one module.
+ auto loc = FileLineColLoc::get(context, "<shared-library-module>", 0, 0);
+ OwningOpRef<ModuleOp> mergedParsedLibraries =
+ ModuleOp::create(loc, "__transform");
+ {
+ mergedParsedLibraries.get()->setAttr("transform.with_named_sequence",
+ UnitAttr::get(context));
+ IRRewriter rewriter(context);
+ // TODO: extend `mergeSymbolsInto` to support multiple `other` modules.
+ for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) {
+ if (failed(transform::detail::mergeSymbolsInto(
+ mergedParsedLibraries.get(), std::move(parsedLibrary))))
+ return mergedParsedLibraries->emitError()
+ << "failed to verify merged transform module";
+ }
+ }
+
+ transformModule = std::move(mergedParsedLibraries);
+ return success();
+}
+
+/// Return whether `func1` can be merged into `func2`. For that to work
+/// `func1` has to be a declaration (aka has to be external) and `func2`
+/// either has to be a declaration as well, or it has to be public (otherwise,
+/// it wouldn't be visible by `func1`).
static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
return func1.isExternal() && (func2.isPublic() || func2.isExternal());
}
@@ -281,8 +376,9 @@ transform::detail::mergeSymbolsInto(Operation *target,
auto collidingFuncOp =
cast<FunctionOpInterface>(collidingOp.getOperation());
- // Both ops are in the target module now and can be treated symmetrically,
- // so w.l.o.g. we can reduce to merging `funcOp` into `collidingFuncOp`.
+ // Both ops are in the target module now and can be treated
+ // symmetrically, so w.l.o.g. we can reduce to merging `funcOp` into
+ // `collidingFuncOp`.
if (!canMergeInto(funcOp, collidingFuncOp)) {
std::swap(funcOp, collidingFuncOp);
}
@@ -317,18 +413,25 @@ LogicalResult transform::applyTransformNamedSequence(
const TransformOptions &options, StringRef entryPoint) {
Operation *transformRoot =
detail::findTransformEntryPoint(payload, transformModule, entryPoint);
- if (!transformRoot)
- return failure();
+ if (!transformRoot) {
+ return payload->emitError()
+ << "could not find transform entry point: " << entryPoint
+ << " in either payload or transform module";
+ }
// `transformModule` may not be modified.
- OwningOpRef<Operation *> clonedTransformModule(transformModule->clone());
if (transformModule && !transformModule->isAncestor(transformRoot)) {
+ OwningOpRef<Operation *> clonedTransformModule(transformModule->clone());
if (failed(detail::mergeSymbolsInto(
SymbolTable::getNearestSymbolTable(transformRoot),
- std::move(clonedTransformModule))))
- return failure();
+ std::move(clonedTransformModule)))) {
+ return payload->emitError() << "failed to merge symbols";
+ }
}
+ LLVM_DEBUG(DBGS() << "Apply\n" << *transformRoot << "\n");
+ LLVM_DEBUG(DBGS() << "To\n" << *payload << "\n");
+
// Apply the transform to the IR, do not enforce top-level constraints.
RaggedA...
[truncated]
``````````
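The revised tail of `applyTransformNamedSequence` now emits diagnostics instead of failing silently. It also clones the transform module only when that module actually has to be merged into the entry point's symbol table. Below is a hedged model of that control flow, with `std::map` symbol tables and a hypothetical `prepareEntryPoint` function. The payload is searched before the transform module, and conflict detection is simplified to "same name, different body".

```cpp
#include <map>
#include <string>
#include <utility>

using SymbolTable = std::map<std::string, std::string>; // symbol -> body

struct ApplyOutcome {
  bool ok = false;
  bool clonedTransformModule = false;
  std::string diagnostic;
};

// Hypothetical model; `transformModule` is read-only, like in the patch.
ApplyOutcome prepareEntryPoint(SymbolTable &payloadSymbols,
                               const SymbolTable &transformModule,
                               const std::string &entryPoint) {
  ApplyOutcome out;
  bool inPayload = payloadSymbols.count(entryPoint) != 0;
  bool inModule = transformModule.count(entryPoint) != 0;
  if (!inPayload && !inModule) {
    out.diagnostic = "could not find transform entry point: " + entryPoint +
                     " in either payload or transform module";
    return out;
  }
  // Entry point lives in the transform module itself: no copy is needed.
  if (!inPayload) {
    out.ok = true;
    return out;
  }
  // Merging consumes its input, so hand it a copy made only on this path.
  SymbolTable clone = transformModule;
  out.clonedTransformModule = true;
  for (auto &[name, body] : clone) {
    // try_emplace leaves `body` untouched when the key already exists.
    auto [it, inserted] = payloadSymbols.try_emplace(name, std::move(body));
    if (!inserted && it->second != body) {
      out.diagnostic = "failed to merge symbols";
      return out;
    }
  }
  out.ok = true;
  return out;
}
```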
</details>
https://github.com/llvm/llvm-project/pull/68661