[Mlir-commits] [mlir] 6a2071c - [mlir][transform] Allow passing various library files to interpreter. (#67120)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Oct 6 03:52:53 PDT 2023
Author: Ingo Müller
Date: 2023-10-06T12:52:49+02:00
New Revision: 6a2071cc6a129dfe645ef4743fda78e76d748f16
URL: https://github.com/llvm/llvm-project/commit/6a2071cc6a129dfe645ef4743fda78e76d748f16
DIFF: https://github.com/llvm/llvm-project/commit/6a2071cc6a129dfe645ef4743fda78e76d748f16.diff
LOG: [mlir][transform] Allow passing various library files to interpreter. (#67120)
The transform interpreter accepts an argument to a "library" file with
named sequences. This patch extends this functionality such that (1)
several such individual files are accepted and (2) folders can be passed
in, in which all `*.mlir` files are loaded.
Added:
mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-dir.mlir
mlir/test/Dialect/Transform/test-interpreter-library/definitions-self-contained.mlir
mlir/test/Dialect/Transform/test-interpreter-library/definitions-with-unresolved.mlir
Modified:
mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir
mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir
mlir/test/Integration/Dialect/Transform/match_matmul.mlir
mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
Removed:
mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
index a6f0dddebd7eacf..16ef0bc6a739200 100644
--- a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
@@ -33,7 +33,7 @@ namespace detail {
/// Template-free implementation of TransformInterpreterPassBase::initialize.
LogicalResult interpreterBaseInitializeImpl(
MLIRContext *context, StringRef transformFileName,
- StringRef transformLibraryFileName,
+ ArrayRef<std::string> transformLibraryPaths,
std::shared_ptr<OwningOpRef<ModuleOp>> &module,
std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
@@ -48,7 +48,7 @@ LogicalResult interpreterBaseRunOnOperationImpl(
const RaggedArray<MappedValue> &extraMappings,
const TransformOptions &options,
const Pass::Option<std::string> &transformFileName,
- const Pass::Option<std::string> &transformLibraryFileName,
+ const Pass::ListOption<std::string> &transformLibraryPaths,
const Pass::Option<std::string> &debugPayloadRootTag,
const Pass::Option<std::string> &debugTransformRootTag,
StringRef binaryName);
@@ -62,11 +62,12 @@ LogicalResult interpreterBaseRunOnOperationImpl(
/// transform script. If empty, `debugTransformRootTag` is considered or the
/// pass root operation must contain a single top-level transform op that
/// will be interpreted.
-/// - transformLibraryFileName: if non-empty, the module in this file will be
+/// - transformLibraryPaths: if non-empty, the modules in these files will be
/// merged into the main transform script run by the interpreter before
/// execution. This allows to provide definitions for external functions
-/// used in the main script. Other public symbols in the library module may
-/// lead to collisions with public symbols in the main script.
+/// used in the main script. Other public symbols in the library modules may
+/// lead to collisions with public symbols in the main script and among each
+/// other.
/// - debugPayloadRootTag: if non-empty, the value of the attribute named
/// `kTransformDialectTagAttrName` indicating the single op that is
/// considered the payload root of the transform interpreter; otherwise, the
@@ -118,16 +119,26 @@ class TransformInterpreterPassBase : public GeneratedBase<Concrete> {
REQUIRE_PASS_OPTION(transformFileName);
REQUIRE_PASS_OPTION(debugPayloadRootTag);
REQUIRE_PASS_OPTION(debugTransformRootTag);
- REQUIRE_PASS_OPTION(transformLibraryFileName);
#undef REQUIRE_PASS_OPTION
+#define REQUIRE_PASS_LIST_OPTION(NAME) \
+ static_assert( \
+ std::is_same_v< \
+ std::remove_reference_t<decltype(std::declval<Concrete &>().NAME)>, \
+ Pass::ListOption<std::string>>, \
+ "required " #NAME " string pass option is missing")
+
+ REQUIRE_PASS_LIST_OPTION(transformLibraryPaths);
+
+#undef REQUIRE_PASS_LIST_OPTION
+
StringRef transformFileName =
static_cast<Concrete *>(this)->transformFileName;
- StringRef transformLibraryFileName =
- static_cast<Concrete *>(this)->transformLibraryFileName;
+ ArrayRef<std::string> transformLibraryPaths =
+ static_cast<Concrete *>(this)->transformLibraryPaths;
return detail::interpreterBaseInitializeImpl(
- context, transformFileName, transformLibraryFileName,
+ context, transformFileName, transformLibraryPaths,
sharedTransformModule, transformLibraryModule,
[this](OpBuilder &builder, Location loc) {
return static_cast<Concrete *>(this)->constructTransformModule(
@@ -162,7 +173,7 @@ class TransformInterpreterPassBase : public GeneratedBase<Concrete> {
op, pass->getArgument(), sharedTransformModule,
transformLibraryModule,
/*extraMappings=*/{}, options, pass->transformFileName,
- pass->transformLibraryFileName, pass->debugPayloadRootTag,
+ pass->transformLibraryPaths, pass->debugPayloadRootTag,
pass->debugTransformRootTag, binaryName)) ||
failed(pass->runAfterInterpreter(op))) {
return pass->signalPassFailure();
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index 68a735e7ef8e0b3..764d7e25854206e 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
@@ -194,7 +195,7 @@ saveReproToTempFile(llvm::raw_ostream &os, Operation *target,
Operation *transform, StringRef passName,
const Pass::Option<std::string> &debugPayloadRootTag,
const Pass::Option<std::string> &debugTransformRootTag,
- const Pass::Option<std::string> &transformLibraryFileName,
+ const Pass::ListOption<std::string> &transformLibraryPaths,
StringRef binaryName) {
using llvm::sys::fs::TempFile;
Operation *root = getRootOperation(target);
@@ -231,7 +232,7 @@ static void performOptionalDebugActions(
Operation *target, Operation *transform, StringRef passName,
const Pass::Option<std::string> &debugPayloadRootTag,
const Pass::Option<std::string> &debugTransformRootTag,
- const Pass::Option<std::string> &transformLibraryFileName,
+ const Pass::ListOption<std::string> &transformLibraryPaths,
StringRef binaryName) {
MLIRContext *context = target->getContext();
@@ -284,7 +285,7 @@ static void performOptionalDebugActions(
DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_FILE, {
saveReproToTempFile(llvm::dbgs(), target, transform, passName,
debugPayloadRootTag, debugTransformRootTag,
- transformLibraryFileName, binaryName);
+ transformLibraryPaths, binaryName);
});
// Remove temporary attributes if they were set.
@@ -534,7 +535,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
const RaggedArray<MappedValue> &extraMappings,
const TransformOptions &options,
const Pass::Option<std::string> &transformFileName,
- const Pass::Option<std::string> &transformLibraryFileName,
+ const Pass::ListOption<std::string> &transformLibraryPaths,
const Pass::Option<std::string> &debugPayloadRootTag,
const Pass::Option<std::string> &debugTransformRootTag,
StringRef binaryName) {
@@ -597,7 +598,8 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
if (failed(
mergeSymbolsInto(SymbolTable::getNearestSymbolTable(transformRoot),
transformLibraryModule->get()->clone())))
- return failure();
+ return emitError(transformRoot->getLoc(),
+ "failed to merge library symbols into transform root");
}
// Step 4
@@ -606,7 +608,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
// repro to stderr and/or a file.
performOptionalDebugActions(target, transformRoot, passName,
debugPayloadRootTag, debugTransformRootTag,
- transformLibraryFileName, binaryName);
+ transformLibraryPaths, binaryName);
// Step 5
// ------
@@ -615,55 +617,148 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
extraMappings, options);
}
+/// Expands the given list of `paths` to a list of `.mlir` files.
+///
+/// Each entry in `paths` must be a regular file, which is added as is, or a
+/// directory, whose regular `.mlir` files (symlinks are followed) are added;
+/// other directory entries are skipped and any other `paths` entry is an error.
+static LogicalResult
+expandPathsToMLIRFiles(ArrayRef<std::string> paths, MLIRContext *const context,
+ SmallVectorImpl<std::string> &fileNames) {
+ for (const std::string &path : paths) {
+ auto loc = FileLineColLoc::get(context, path, 0, 0);
+
+ if (llvm::sys::fs::is_regular_file(path)) {
+ LLVM_DEBUG(DBGS() << "Adding '" << path << "' to list of files\n");
+ fileNames.push_back(path);
+ continue;
+ }
+
+ if (!llvm::sys::fs::is_directory(path)) {
+ return emitError(loc)
+ << "'" << path << "' is neither a file nor a directory";
+ }
+
+ LLVM_DEBUG(DBGS() << "Looking for files in '" << path << "':\n");
+
+ std::error_code ec;
+ for (llvm::sys::fs::directory_iterator it(path, ec), itEnd;
+ it != itEnd && !ec; it.increment(ec)) {
+ const std::string &fileName = it->path();
+
+ if (!StringRef(fileName).endswith(".mlir")) {
+ LLVM_DEBUG(DBGS() << " Skipping '" << fileName
+ << "' because it does not end with '.mlir'\n");
+ continue;
+ }
+
+ // Stat (following symlinks): `it->type()` may be a symlink or unknown.
+ if (!llvm::sys::fs::is_regular_file(fileName)) {
+ LLVM_DEBUG(DBGS() << " Skipping non-regular file '" << fileName
+ << "'\n");
+ continue;
+ }
+ LLVM_DEBUG(DBGS() << " Adding '" << fileName << "' to list of files\n");
+ fileNames.push_back(fileName);
+ }
+
+ if (ec)
+ return emitError(loc) << "error while opening files in '" << path
+ << "': " << ec.message();
+ }
+
+ return success();
+}
+
LogicalResult transform::detail::interpreterBaseInitializeImpl(
MLIRContext *context, StringRef transformFileName,
- StringRef transformLibraryFileName,
+ ArrayRef<std::string> transformLibraryPaths,
std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
std::shared_ptr<OwningOpRef<ModuleOp>> &transformLibraryModule,
function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
moduleBuilder) {
- OwningOpRef<ModuleOp> parsedTransformModule;
- if (failed(parseTransformModuleFromFile(context, transformFileName,
- parsedTransformModule)))
- return failure();
- if (parsedTransformModule && failed(mlir::verify(*parsedTransformModule)))
- return failure();
+ auto unknownLoc = UnknownLoc::get(context);
- OwningOpRef<ModuleOp> parsedLibraryModule;
- if (failed(parseTransformModuleFromFile(context, transformLibraryFileName,
- parsedLibraryModule)))
- return failure();
- if (parsedLibraryModule && failed(mlir::verify(*parsedLibraryModule)))
+ // Parse module from file.
+ OwningOpRef<ModuleOp> moduleFromFile;
+ {
+ auto loc = FileLineColLoc::get(context, transformFileName, 0, 0);
+ if (failed(parseTransformModuleFromFile(context, transformFileName,
+ moduleFromFile)))
+ return emitError(loc) << "failed to parse transform module";
+ if (moduleFromFile && failed(mlir::verify(*moduleFromFile)))
+ return emitError(loc) << "failed to verify transform module";
+ }
+
+ // Assemble list of library files.
+ SmallVector<std::string> libraryFileNames;
+ if (failed(expandPathsToMLIRFiles(transformLibraryPaths, context,
+ libraryFileNames)))
return failure();
- if (parsedTransformModule) {
- sharedTransformModule = std::make_shared<OwningOpRef<ModuleOp>>(
- std::move(parsedTransformModule));
+ // Parse modules from library files.
+ SmallVector<OwningOpRef<ModuleOp>> parsedLibraries;
+ for (const std::string &libraryFileName : libraryFileNames) {
+ OwningOpRef<ModuleOp> parsedLibrary;
+ auto loc = FileLineColLoc::get(context, libraryFileName, 0, 0);
+ if (failed(parseTransformModuleFromFile(context, libraryFileName,
+ parsedLibrary)))
+ return emitError(loc) << "failed to parse transform library module";
+ if (parsedLibrary && failed(mlir::verify(*parsedLibrary)))
+ return emitError(loc) << "failed to verify transform library module";
+ parsedLibraries.push_back(std::move(parsedLibrary));
+ }
+
+ // Build shared transform module.
+ if (moduleFromFile) {
+ sharedTransformModule =
+ std::make_shared<OwningOpRef<ModuleOp>>(std::move(moduleFromFile));
} else if (moduleBuilder) {
- // TODO: better location story.
- auto location = UnknownLoc::get(context);
+ auto loc = FileLineColLoc::get(context, "<shared-transform-module>", 0, 0);
auto localModule = std::make_shared<OwningOpRef<ModuleOp>>(
- ModuleOp::create(location, "__transform"));
+ ModuleOp::create(unknownLoc, "__transform"));
OpBuilder b(context);
b.setInsertionPointToEnd(localModule->get().getBody());
- if (std::optional<LogicalResult> result = moduleBuilder(b, location)) {
+ if (std::optional<LogicalResult> result = moduleBuilder(b, loc)) {
if (failed(*result))
- return failure();
+ return (*localModule)->emitError()
+ << "failed to create shared transform module";
sharedTransformModule = std::move(localModule);
}
}
- if (!parsedLibraryModule || !*parsedLibraryModule)
+ if (parsedLibraries.empty())
return success();
+ // Merge parsed libraries into one module.
+ auto loc = FileLineColLoc::get(context, "<shared-library-module>", 0, 0);
+ OwningOpRef<ModuleOp> mergedParsedLibraries =
+ ModuleOp::create(loc, "__transform");
+ {
+ mergedParsedLibraries.get()->setAttr("transform.with_named_sequence",
+ UnitAttr::get(context));
+ // Merge the libraries one by one; the first failure aborts initialization.
+ // TODO: extend `mergeSymbolsInto` to support multiple `other` modules.
+ for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) {
+ if (failed(mergeSymbolsInto(mergedParsedLibraries.get(),
+ std::move(parsedLibrary))))
+ return mergedParsedLibraries->emitError()
+ << "failed to merge symbols into shared library module";
+ }
+ }
+
+ // Use parsed libraries to resolve symbols in the shared transform module or
+ // return them as a separate library module.
if (sharedTransformModule && *sharedTransformModule) {
if (failed(mergeSymbolsInto(sharedTransformModule->get(),
- std::move(parsedLibraryModule))))
- return failure();
+ std::move(mergedParsedLibraries))))
+ return (*sharedTransformModule)->emitError()
+ << "failed to merge symbols from library files "
+ "into shared transform module";
} else {
- transformLibraryModule =
- std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsedLibraryModule));
+ transformLibraryModule = std::make_shared<OwningOpRef<ModuleOp>>(
+ std::move(mergedParsedLibraries));
}
return success();
}
diff --git a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir
index c219cfe08ac4d6a..c25212fbe98782b 100644
--- a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir
@@ -2,7 +2,7 @@
// RUN: mlir-opt %s -test-lower-to-llvm -cse | FileCheck %s
-// RUN: mlir-opt %s -test-transform-dialect-interpreter="transform-library-file-name=%p/lower-to-llvm-transform-symbol-def.mlir debug-payload-root-tag=payload" \
+// RUN: mlir-opt %s -test-transform-dialect-interpreter="transform-library-paths=%p/lower-to-llvm-transform-symbol-def.mlir debug-payload-root-tag=payload" \
// RUN: -test-transform-dialect-erase-schedule -cse \
// RUN: | FileCheck %s
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
index dd8d141e994da0e..2c4812bf32b0f03 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
@@ -1,7 +1,7 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-paths=%p/test-interpreter-library/definitions-self-contained.mlir})" \
// RUN: --verify-diagnostics
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-paths=%p/test-interpreter-library/definitions-self-contained.mlir}, test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-paths=%p/test-interpreter-library/definitions-self-contained.mlir})" \
// RUN: --verify-diagnostics
// The external transform script has a declaration to the named sequence @foo,
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-dir.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-dir.mlir
new file mode 100644
index 000000000000000..8b8254976e9aeec
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-dir.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-paths=%p%{fs-sep}test-interpreter-library})" \
+// RUN: --verify-diagnostics --split-input-file | FileCheck %s
+
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-paths=%p%{fs-sep}test-interpreter-library/definitions-self-contained.mlir,%p%{fs-sep}test-interpreter-library/definitions-with-unresolved.mlir})" \
+// RUN: --verify-diagnostics --split-input-file | FileCheck %s
+
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-paths=%p%{fs-sep}test-interpreter-library}, test-transform-dialect-interpreter)" \
+// RUN: --verify-diagnostics --split-input-file | FileCheck %s
+
+// The definitions of the @print_message and @reference_other_module named
+// sequences are provided in other files, which are included because of the
+// pass option. Repeated application of the same pass, with or without the
+// library option, should not be a problem. Note that the same diagnostic
+// produced twice at the same location only needs to be matched once.
+
+// expected-remark @below {{message}}
+module attributes {transform.with_named_sequence} {
+ // CHECK: transform.named_sequence @print_message
+ transform.named_sequence @print_message(%arg0: !transform.any_op {transform.readonly})
+
+ transform.named_sequence @reference_other_module(!transform.any_op {transform.readonly})
+
+ transform.sequence failures(propagate) {
+ ^bb0(%arg0: !transform.any_op):
+ include @print_message failures(propagate) (%arg0) : (!transform.any_op) -> ()
+ include @reference_other_module failures(propagate) (%arg0) : (!transform.any_op) -> ()
+ }
+}
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
index 7452deb39b6c18d..c1bd071dc138d56 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def-invalid.mlir}, test-transform-dialect-interpreter)" \
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-paths=%p/test-interpreter-external-symbol-def-invalid.mlir}, test-transform-dialect-interpreter)" \
// RUN: --verify-diagnostics --split-input-file
// The definition of the @print_message named sequence is provided in another file. It
@@ -8,6 +8,7 @@ module attributes {transform.with_named_sequence} {
// expected-error @below {{external definition has a mismatching signature}}
transform.named_sequence private @print_message(!transform.op<"builtin.module"> {transform.readonly})
+ // expected-error @below {{failed to merge library symbols into transform root}}
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.op<"builtin.module">):
include @print_message failures(propagate) (%arg0) : (!transform.op<"builtin.module">) -> ()
@@ -32,6 +33,7 @@ module attributes {transform.with_named_sequence} {
// expected-error @below {{external definition has mismatching consumption annotations for argument #0}}
transform.named_sequence private @consuming(%arg0: !transform.any_op {transform.readonly})
+ // expected-error @below {{failed to merge library symbols into transform root}}
transform.sequence failures(suppress) {
^bb0(%arg0: !transform.any_op):
include @consuming failures(suppress) (%arg0) : (!transform.any_op) -> ()
@@ -47,6 +49,7 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
+ // expected-error @below {{failed to merge library symbols into transform root}}
transform.sequence failures(suppress) {
^bb0(%arg0: !transform.any_op):
include @print_message failures(propagate) (%arg0) : (!transform.any_op) -> ()
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
index 7d0837abebde32c..339e62072cd5510 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
@@ -1,7 +1,7 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-paths=%p/test-interpreter-library/definitions-self-contained.mlir})" \
// RUN: --verify-diagnostics --split-input-file | FileCheck %s
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-paths=%p/test-interpreter-library/definitions-self-contained.mlir}, test-transform-dialect-interpreter)" \
// RUN: --verify-diagnostics --split-input-file | FileCheck %s
// The definition of the @print_message named sequence is provided in another
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir b/mlir/test/Dialect/Transform/test-interpreter-library/definitions-self-contained.mlir
similarity index 100%
rename from mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir
rename to mlir/test/Dialect/Transform/test-interpreter-library/definitions-self-contained.mlir
diff --git a/mlir/test/Dialect/Transform/test-interpreter-library/definitions-with-unresolved.mlir b/mlir/test/Dialect/Transform/test-interpreter-library/definitions-with-unresolved.mlir
new file mode 100644
index 000000000000000..b3d076f4698495f
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-library/definitions-with-unresolved.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-opt %s
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @print_message(%arg0: !transform.any_op {transform.readonly})
+
+ transform.named_sequence @reference_other_module(%arg0: !transform.any_op) {
+ transform.include @print_message failures(propagate) (%arg0) : (!transform.any_op) -> ()
+ transform.yield
+ }
+}
diff --git a/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir b/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir
index 73bc243ad76060d..3914dc6198a69c2 100644
--- a/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir
+++ b/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-file-name=%p/match_matmul_common.mlir' --verify-diagnostics
+// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-paths=%p/match_matmul_common.mlir' --verify-diagnostics
module attributes { transform.with_named_sequence } {
transform.named_sequence @_match_matmul_like(
diff --git a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
index f164a3d1bd99dd0..bb0f1125fd39716 100644
--- a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
+++ b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-file-name=%p/match_matmul_common.mlir' --verify-diagnostics
+// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-paths=%p/match_matmul_common.mlir' --verify-diagnostics
module attributes { transform.with_named_sequence } {
transform.named_sequence @_match_matmul_like(
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
index 578d9abe4a56ecb..c60b21c918338b4 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
@@ -161,7 +161,7 @@ class TestTransformDialectInterpreterPass
if (failed(transform::detail::interpreterBaseRunOnOperationImpl(
getOperation(), getArgument(), getSharedTransformModule(),
getTransformLibraryModule(), extraMapping, options,
- transformFileName, transformLibraryFileName, debugPayloadRootTag,
+ transformFileName, transformLibraryPaths, debugPayloadRootTag,
debugTransformRootTag, getBinaryName())))
return signalPassFailure();
}
@@ -216,9 +216,9 @@ class TestTransformDialectInterpreterPass
"the given value as container IR for top-level transform ops. This "
"allows user control on what transformation to apply. If empty, "
"select the container of the top-level transform op.")};
- Option<std::string> transformLibraryFileName{
- *this, "transform-library-file-name", llvm::cl::init(""),
- llvm::cl::desc("Optional name of a file with a module that should be "
+ ListOption<std::string> transformLibraryPaths{
+ *this, "transform-library-paths", llvm::cl::ZeroOrMore,
+ llvm::cl::desc("Optional paths to files with modules that should be "
"merged into the transform module to provide the "
"definitions of external named sequences.")};
More information about the Mlir-commits
mailing list