[Mlir-commits] [mlir] [mlir][transform] Allow passing various library files to interpreter. (PR #67120)
Ingo Müller
llvmlistbot at llvm.org
Fri Sep 22 06:04:01 PDT 2023
https://github.com/ingomueller-net updated https://github.com/llvm/llvm-project/pull/67120
>From 770a245972e57ad08ef62605df6f98aad9aab267 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Fri, 22 Sep 2023 12:05:38 +0000
Subject: [PATCH] [mlir][transform] Allow passing various library files to
interpreter.
The transform interpreter accepts an argument naming a "library" file with
named sequences. This patch extends this functionality such that (1)
several such individual files are accepted and (2) directories can be passed
in, in which case all `*.mlir` files directly inside them are loaded.
---
.../Transforms/TransformInterpreterPassBase.h | 43 ++--
.../TransformInterpreterPassBase.cpp | 186 +++++++++++++-----
.../Dialect/Transform/match_batch_matmul.mlir | 2 +-
.../Dialect/Transform/match_matmul.mlir | 2 +-
.../TestTransformDialectInterpreter.cpp | 17 +-
5 files changed, 186 insertions(+), 64 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
index 91903e254b0d5b3..5d021da90e9a594 100644
--- a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
@@ -33,7 +33,8 @@ namespace detail {
/// Template-free implementation of TransformInterpreterPassBase::initialize.
LogicalResult interpreterBaseInitializeImpl(
MLIRContext *context, StringRef transformFileName,
- StringRef transformLibraryFileName,
+ ArrayRef<std::string> transformLibraryFileNames,
+ ArrayRef<std::string> transformLibraryDirNames,
std::shared_ptr<OwningOpRef<ModuleOp>> &module,
std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
@@ -48,7 +49,8 @@ LogicalResult interpreterBaseRunOnOperationImpl(
const RaggedArray<MappedValue> &extraMappings,
const TransformOptions &options,
const Pass::Option<std::string> &transformFileName,
- const Pass::Option<std::string> &transformLibraryFileName,
+ const Pass::ListOption<std::string> &transformLibraryFileNames,
+ const Pass::ListOption<std::string> &transformLibraryDirNames,
const Pass::Option<std::string> &debugPayloadRootTag,
const Pass::Option<std::string> &debugTransformRootTag,
StringRef binaryName);
@@ -62,9 +64,14 @@ LogicalResult interpreterBaseRunOnOperationImpl(
/// transform script. If empty, `debugTransformRootTag` is considered or the
/// pass root operation must contain a single top-level transform op that
/// will be interpreted.
-/// - transformLibraryFileName: if non-empty, the name of the file containing
+/// - transformLibraryFileNames: if non-empty, the names of files containing
/// definitions of external symbols referenced in the transform script.
-/// These definitions will be used to replace declarations.
+/// These definitions will be used to replace declarations and must be
+/// unique within all files provided by this and the next option.
+/// - transformLibraryDirNames: if non-empty, the names of directories
+/// containing definitions of external symbols referenced in the transform
+/// script. These definitions will be used to replace declarations and must
+/// be unique within all files provided by this and the previous option.
/// - debugPayloadRootTag: if non-empty, the value of the attribute named
/// `kTransformDialectTagAttrName` indicating the single op that is
/// considered the payload root of the transform interpreter; otherwise, the
@@ -115,17 +122,30 @@ class TransformInterpreterPassBase : public GeneratedBase<Concrete> {
REQUIRE_PASS_OPTION(transformFileName);
REQUIRE_PASS_OPTION(debugPayloadRootTag);
REQUIRE_PASS_OPTION(debugTransformRootTag);
- REQUIRE_PASS_OPTION(transformLibraryFileName);
#undef REQUIRE_PASS_OPTION
+#define REQUIRE_PASS_LIST_OPTION(NAME) \
+ static_assert( \
+ std::is_same_v< \
+ std::remove_reference_t<decltype(std::declval<Concrete &>().NAME)>, \
+ Pass::ListOption<std::string>>, \
+ "required " #NAME " string pass option is missing")
+
+ REQUIRE_PASS_LIST_OPTION(transformLibraryFileNames);
+ REQUIRE_PASS_LIST_OPTION(transformLibraryDirNames);
+
+#undef REQUIRE_PASS_LIST_OPTION
+
StringRef transformFileName =
static_cast<Concrete *>(this)->transformFileName;
- StringRef transformLibraryFileName =
- static_cast<Concrete *>(this)->transformLibraryFileName;
+ ArrayRef<std::string> transformLibraryFileNames =
+ static_cast<Concrete *>(this)->transformLibraryFileNames;
+ ArrayRef<std::string> transformLibraryDirNames =
+ static_cast<Concrete *>(this)->transformLibraryDirNames;
return detail::interpreterBaseInitializeImpl(
- context, transformFileName, transformLibraryFileName,
- sharedTransformModule, transformLibraryModule,
+ context, transformFileName, transformLibraryFileNames,
+ transformLibraryDirNames, sharedTransformModule, transformLibraryModule,
[this](OpBuilder &builder, Location loc) {
return static_cast<Concrete *>(this)->constructTransformModule(
builder, loc);
@@ -159,8 +179,9 @@ class TransformInterpreterPassBase : public GeneratedBase<Concrete> {
op, pass->getArgument(), sharedTransformModule,
transformLibraryModule,
/*extraMappings=*/{}, options, pass->transformFileName,
- pass->transformLibraryFileName, pass->debugPayloadRootTag,
- pass->debugTransformRootTag, binaryName)) ||
+ pass->transformLibraryFileNames, pass->transformLibraryDirNames,
+ pass->debugPayloadRootTag, pass->debugTransformRootTag,
+ binaryName)) ||
failed(pass->runAfterInterpreter(op))) {
return pass->signalPassFailure();
}
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index d5c65b23e3a2134..4beb8f31514a1c9 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -161,15 +161,24 @@ static llvm::raw_ostream &
printReproCall(llvm::raw_ostream &os, StringRef rootOpName, StringRef passName,
const Pass::Option<std::string> &debugPayloadRootTag,
const Pass::Option<std::string> &debugTransformRootTag,
- const Pass::Option<std::string> &transformLibraryFileName,
+ const Pass::ListOption<std::string> &transformLibraryFileNames,
+ const Pass::ListOption<std::string> &transformLibraryDirNames,
StringRef binaryName) {
- std::string transformLibraryOption = "";
- if (!transformLibraryFileName.empty()) {
- transformLibraryOption =
- llvm::formatv(" {0}={1}", transformLibraryFileName.getArgStr(),
- transformLibraryFileName.getValue())
- .str();
+ std::string transformLibraryOptions = "";
+ {
+ llvm::raw_string_ostream optionStream(transformLibraryOptions);
+ if (!transformLibraryFileNames.empty()) {
+ optionStream << " " << transformLibraryFileNames.getArgStr() << "='";
+ llvm::interleave(transformLibraryFileNames, optionStream, ",");
+ optionStream << "'";
+ }
+ if (!transformLibraryDirNames.empty()) {
+ optionStream << " " << transformLibraryDirNames.getArgStr() << "='";
+ llvm::interleave(transformLibraryDirNames, optionStream, ",");
+ optionStream << "'";
+ }
}
+
os << llvm::formatv(
"{7} --pass-pipeline=\"{0}({1}{{{2}={3} {4}={5}{6}})\"", rootOpName,
passName, debugPayloadRootTag.getArgStr(),
@@ -180,7 +189,7 @@ printReproCall(llvm::raw_ostream &os, StringRef rootOpName, StringRef passName,
debugTransformRootTag.empty()
? StringRef(kTransformDialectTagTransformContainerValue)
: debugTransformRootTag,
- transformLibraryOption, binaryName);
+ transformLibraryOptions, binaryName);
return os;
}
@@ -200,7 +209,8 @@ void saveReproToTempFile(
llvm::raw_ostream &os, Operation *target, Operation *transform,
StringRef passName, const Pass::Option<std::string> &debugPayloadRootTag,
const Pass::Option<std::string> &debugTransformRootTag,
- const Pass::Option<std::string> &transformLibraryFileName,
+ const Pass::ListOption<std::string> &transformLibraryFileNames,
+ const Pass::ListOption<std::string> &transformLibraryDirNames,
StringRef binaryName) {
using llvm::sys::fs::TempFile;
Operation *root = getRootOperation(target);
@@ -227,7 +237,8 @@ void saveReproToTempFile(
os << "=== Transform Interpreter Repro ===\n";
printReproCall(os, root->getName().getStringRef(), passName,
debugPayloadRootTag, debugTransformRootTag,
- transformLibraryFileName, binaryName)
+ transformLibraryFileNames, transformLibraryDirNames,
+ binaryName)
<< " " << filename << "\n";
os << "===================================\n";
}
@@ -238,7 +249,8 @@ static void performOptionalDebugActions(
Operation *target, Operation *transform, StringRef passName,
const Pass::Option<std::string> &debugPayloadRootTag,
const Pass::Option<std::string> &debugTransformRootTag,
- const Pass::Option<std::string> &transformLibraryFileName,
+ const Pass::ListOption<std::string> &transformLibraryFileNames,
+ const Pass::ListOption<std::string> &transformLibraryDirNames,
StringRef binaryName) {
MLIRContext *context = target->getContext();
@@ -279,10 +291,10 @@ static void performOptionalDebugActions(
DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_STDERR, {
llvm::dbgs() << "=== Transform Interpreter Repro ===\n";
- printReproCall(llvm::dbgs() << "cat <<EOF | ",
- root->getName().getStringRef(), passName,
- debugPayloadRootTag, debugTransformRootTag,
- transformLibraryFileName, binaryName)
+ printReproCall(
+ llvm::dbgs() << "cat <<EOF | ", root->getName().getStringRef(),
+ passName, debugPayloadRootTag, debugTransformRootTag,
+ transformLibraryFileNames, transformLibraryDirNames, binaryName)
<< "\n";
printModuleForRepro(llvm::dbgs(), root, transform);
llvm::dbgs() << "\nEOF\n";
@@ -292,7 +304,8 @@ static void performOptionalDebugActions(
DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_FILE, {
saveReproToTempFile(llvm::dbgs(), target, transform, passName,
debugPayloadRootTag, debugTransformRootTag,
- transformLibraryFileName, binaryName);
+ transformLibraryFileNames, transformLibraryDirNames,
+ binaryName);
});
// Remove temporary attributes if they were set.
@@ -383,7 +396,8 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
const RaggedArray<MappedValue> &extraMappings,
const TransformOptions &options,
const Pass::Option<std::string> &transformFileName,
- const Pass::Option<std::string> &transformLibraryFileName,
+ const Pass::ListOption<std::string> &transformLibraryFileNames,
+ const Pass::ListOption<std::string> &transformLibraryDirNames,
const Pass::Option<std::string> &debugPayloadRootTag,
const Pass::Option<std::string> &debugTransformRootTag,
StringRef binaryName) {
@@ -449,7 +463,8 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
// repro to stderr and/or a file.
performOptionalDebugActions(target, transformRoot, passName,
debugPayloadRootTag, debugTransformRootTag,
- transformLibraryFileName, binaryName);
+ transformLibraryFileNames,
+ transformLibraryDirNames, binaryName);
// Step 5
// ------
@@ -460,51 +475,132 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
LogicalResult transform::detail::interpreterBaseInitializeImpl(
MLIRContext *context, StringRef transformFileName,
- StringRef transformLibraryFileName,
- std::shared_ptr<OwningOpRef<ModuleOp>> &module,
- std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
+ ArrayRef<std::string> transformLibraryFileNames,
+ ArrayRef<std::string> transformLibraryDirNames,
+ std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
+ std::shared_ptr<OwningOpRef<ModuleOp>> &transformLibraryModule,
function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
moduleBuilder) {
- OwningOpRef<ModuleOp> parsed;
- if (failed(parseTransformModuleFromFile(context, transformFileName, parsed)))
- return failure();
- if (parsed && failed(mlir::verify(*parsed)))
- return failure();
+ // Parse module from file.
+ OwningOpRef<ModuleOp> moduleFromFile;
+ {
+ auto loc = FileLineColLoc::get(context, transformFileName, 0, 0);
+ if (failed(parseTransformModuleFromFile(context, transformFileName,
+ moduleFromFile))) {
+ emitError(loc, "failed to parse transform module");
+ return failure();
+ }
+ if (moduleFromFile && failed(mlir::verify(*moduleFromFile))) {
+ emitError(loc, "failed to verify transform module");
+ return failure();
+ }
+ }
- OwningOpRef<ModuleOp> parsedLibrary;
- if (failed(parseTransformModuleFromFile(context, transformLibraryFileName,
- parsedLibrary)))
- return failure();
- if (parsedLibrary && failed(mlir::verify(*parsedLibrary)))
- return failure();
+ // Assemble list of library files.
+ SmallVector<std::string> libraryFileNames;
+ libraryFileNames.append(transformLibraryFileNames.begin(),
+ transformLibraryFileNames.end());
+
+ for (const std::string &dirName : transformLibraryDirNames) {
+ DBGS() << "Opening files in '" << dirName << "':\n";
- if (parsed) {
- module = std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsed));
+ std::error_code ec;
+ for (llvm::sys::fs::directory_iterator it(dirName, ec), itEnd;
+ it != itEnd && !ec; it.increment(ec)) {
+ const std::string &fileName = it->path();
+
+ if (it->type() != llvm::sys::fs::file_type::regular_file) {
+ DBGS() << " Skipping non-regular file '" << fileName << "'\n";
+ continue;
+ }
+
+ if (!StringRef(fileName).endswith(".mlir")) {
+ DBGS() << " Skipping '" << fileName
+ << "' because it does not end with '.mlir'\n";
+ continue;
+ }
+
+ DBGS() << " Adding '" << fileName << "' to list of files\n";
+ libraryFileNames.push_back(fileName);
+ }
+
+ if (ec) {
+ auto loc = UnknownLoc::get(context);
+ emitError(loc, Twine("error while opening files in '") + dirName +
+ "': " + ec.message());
+ return failure();
+ }
+ }
+
+ // Parse modules from library files.
+ SmallVector<OwningOpRef<ModuleOp>> parsedLibraries;
+ for (const std::string &libraryFileName : libraryFileNames) {
+ OwningOpRef<ModuleOp> parsedLibrary;
+ auto loc = FileLineColLoc::get(context, libraryFileName, 0, 0);
+ if (failed(parseTransformModuleFromFile(context, libraryFileName,
+ parsedLibrary))) {
+ emitError(loc, "failed to parse transform library module");
+ return failure();
+ }
+ if (parsedLibrary && failed(mlir::verify(*parsedLibrary))) {
+ emitError(loc, "failed to verify transform library module");
+ return failure();
+ }
+ parsedLibraries.push_back(std::move(parsedLibrary));
+ }
+
+ // Build shared transform module.
+ if (moduleFromFile) {
+ sharedTransformModule =
+ std::make_shared<OwningOpRef<ModuleOp>>(std::move(moduleFromFile));
} else if (moduleBuilder) {
// TODO: better location story.
- auto location = UnknownLoc::get(context);
+ auto loc = UnknownLoc::get(context);
auto localModule = std::make_shared<OwningOpRef<ModuleOp>>(
- ModuleOp::create(location, "__transform"));
+ ModuleOp::create(loc, "__transform"));
OpBuilder b(context);
b.setInsertionPointToEnd(localModule->get().getBody());
- if (std::optional<LogicalResult> result = moduleBuilder(b, location)) {
- if (failed(*result))
+ if (std::optional<LogicalResult> result = moduleBuilder(b, loc)) {
+ if (failed(*result)) {
+ emitError(loc, "failed to create shared transform module");
return failure();
- module = std::move(localModule);
+ }
+ sharedTransformModule = std::move(localModule);
}
}
- if (!parsedLibrary || !*parsedLibrary)
+ if (parsedLibraries.empty())
return success();
- if (module && *module) {
- if (failed(defineDeclaredSymbols(*module->get().getBody(),
- parsedLibrary.get())))
+ // Merge parsed libraries into one module.
+ // TODO: better location story.
+ auto loc = UnknownLoc::get(context);
+ OwningOpRef<ModuleOp> mergedParsedLibraries =
+ ModuleOp::create(loc, "__transform");
+ mergedParsedLibraries.get()->setAttr("transform.with_named_sequence",
+ UnitAttr::get(context));
+
+ IRRewriter rewriter(context);
+ rewriter.setInsertionPointToEnd(mergedParsedLibraries->getBody());
+ for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) {
+ rewriter.inlineBlockBefore(parsedLibrary->getBody(),
+ mergedParsedLibraries->getBody(),
+ mergedParsedLibraries->getBody()->end());
+ if (failed(mlir::verify(*mergedParsedLibraries))) {
+ emitError(loc, "failed to verify merged transform module");
+ return failure();
+ }
+ }
+
+ // Merge parsed libraries into shared module or return as library module.
+ if (sharedTransformModule && *sharedTransformModule) {
+ if (failed(defineDeclaredSymbols(*sharedTransformModule->get().getBody(),
+ mergedParsedLibraries.get())))
return failure();
} else {
- libraryModule =
- std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsedLibrary));
+ transformLibraryModule = std::make_shared<OwningOpRef<ModuleOp>>(
+ std::move(mergedParsedLibraries));
}
return success();
}
diff --git a/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir b/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir
index 73bc243ad76060d..c7c055c7567e6f3 100644
--- a/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir
+++ b/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-file-name=%p/match_matmul_common.mlir' --verify-diagnostics
+// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-file-names=%p/match_matmul_common.mlir' --verify-diagnostics
module attributes { transform.with_named_sequence } {
transform.named_sequence @_match_matmul_like(
diff --git a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
index f164a3d1bd99dd0..3e06f2ca0895841 100644
--- a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
+++ b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-file-name=%p/match_matmul_common.mlir' --verify-diagnostics
+// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-file-names=%p/match_matmul_common.mlir' --verify-diagnostics
module attributes { transform.with_named_sequence } {
transform.named_sequence @_match_matmul_like(
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
index f73deef9d5fd48c..28e5256385efaea 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
@@ -161,7 +161,8 @@ class TestTransformDialectInterpreterPass
if (failed(transform::detail::interpreterBaseRunOnOperationImpl(
getOperation(), getArgument(), getSharedTransformModule(),
getTransformLibraryModule(), extraMapping, options,
- transformFileName, transformLibraryFileName, debugPayloadRootTag,
+ transformFileName, transformLibraryFileNames,
+ transformLibraryDirNames, debugPayloadRootTag,
debugTransformRootTag, getBinaryName())))
return signalPassFailure();
}
@@ -216,12 +217,16 @@ class TestTransformDialectInterpreterPass
"the given value as container IR for top-level transform ops. This "
"allows user control on what transformation to apply. If empty, "
"select the container of the top-level transform op.")};
- Option<std::string> transformLibraryFileName{
- *this, "transform-library-file-name", llvm::cl::init(""),
+ ListOption<std::string> transformLibraryFileNames{
+ *this, "transform-library-file-names", llvm::cl::ZeroOrMore,
+ llvm::cl::desc("Optional filenames containing transform dialect symbol "
+ "definitions to be injected into the transform module.")};
+ ListOption<std::string> transformLibraryDirNames{
+ *this, "transform-library-dir-names", llvm::cl::ZeroOrMore,
llvm::cl::desc(
- "Optional name of the file containing transform dialect symbol "
- "definitions to be injected into the transform module.")};
-
+ "Optional directories containing transform dialect symbol "
+ "definitions to be injected into the transform module. All '.mlir' "
+ "files rooted under this directory will be loaded.")};
Option<bool> testModuleGeneration{
*this, "test-module-generation", llvm::cl::init(false),
llvm::cl::desc("test the generation of the transform module during pass "
More information about the Mlir-commits
mailing list