[Mlir-commits] [mlir] [mlir][Transform] Provide a minimal set of utils that allow implementing a simple transform dialect interpreter pass (PR #68330)
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Oct 6 05:08:02 PDT 2023
https://github.com/nicolasvasilache updated https://github.com/llvm/llvm-project/pull/68330
>From cdfa54029105b464d72fedd0d47dedae2fbef01d Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolas.vasilache at gmail.com>
Date: Thu, 5 Oct 2023 16:11:00 +0000
Subject: [PATCH] [mlir][Transform] Provide a minimal set of utils that allow
implementing a simple transform dialect interpreter pass
---
.../Dialect/Transform/IR/TransformDialect.td | 29 +-
.../Transform/IR/TransformInterfaces.h | 5 +-
.../Transforms/TransformInterpreterUtils.h | 89 +++++
.../Transform/IR/TransformInterfaces.cpp | 26 +-
.../Transform/Transforms/CMakeLists.txt | 1 +
.../TransformInterpreterPassBase.cpp | 284 +--------------
.../Transforms/TransformInterpreterUtils.cpp | 337 ++++++++++++++++++
.../Dialect/Transform/CMakeLists.txt | 3 +
mlir/unittests/Dialect/Transform/Preload.cpp | 92 +++++
9 files changed, 577 insertions(+), 289 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
create mode 100644 mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
create mode 100644 mlir/unittests/Dialect/Transform/Preload.cpp
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
index f28205a25507025..ad6804673b770ca 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
@@ -20,10 +20,14 @@ def Transform_Dialect : Dialect {
let hasOperationAttrVerify = 1;
let extraClassDeclaration = [{
+ /// Symbol name for the default entry point "named sequence".
+ constexpr const static ::llvm::StringLiteral
+ kTransformEntryPointSymbolName = "__transform_main";
+
/// Name of the attribute attachable to the symbol table operation
/// containing named sequences. This is used to trigger verification.
- constexpr const static ::llvm::StringLiteral kWithNamedSequenceAttrName =
- "transform.with_named_sequence";
+ constexpr const static ::llvm::StringLiteral
+ kWithNamedSequenceAttrName = "transform.with_named_sequence";
/// Name of the attribute attachable to an operation so it can be
/// identified as root by the default interpreter pass.
@@ -74,6 +78,22 @@ def Transform_Dialect : Dialect {
using ExtensionTypePrintingHook =
std::function<void(::mlir::Type, ::mlir::AsmPrinter &)>;
+    /// Takes ownership of `library` and makes it available, as a transform
+    /// symbol library, to every user of this dialect.
+    void registerLibraryModule(
+        ::mlir::OwningOpRef<::mlir::ModuleOp> &&library) {
+      libraryModules.emplace_back(std::move(library));
+    }
+
+    /// Returns a lazily mapped range of non-owning handles to the library
+    /// modules registered so far, in registration order.
+    auto getLibraryModules() const {
+      auto toModule = [](const ::mlir::OwningOpRef<::mlir::ModuleOp> &owned) {
+        return owned.get();
+      };
+      return ::llvm::map_range(libraryModules, toModule);
+    }
+
private:
/// Registers operations specified as template parameters with this
/// dialect. Checks that they implement the required interfaces.
@@ -132,6 +152,11 @@ def Transform_Dialect : Dialect {
/// lookups when the type is fully constructed.
::llvm::DenseMap<::mlir::TypeID, ExtensionTypePrintingHook>
typePrintingHooks;
+
+ /// Modules containing symbols, e.g. named sequences, that will be
+ /// resolved by the interpreter when used.
+ ::llvm::SmallVector<::mlir::OwningOpRef<::mlir::ModuleOp>, 2>
+ libraryModules;
}];
}
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 0e72a93e685e32f..7b37245fc3d117b 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -111,7 +111,8 @@ class TransformOptions {
LogicalResult
applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
const RaggedArray<MappedValue> &extraMapping = {},
- const TransformOptions &options = TransformOptions());
+ const TransformOptions &options = TransformOptions(),
+ bool enforceToplevelTransformOp = true);
/// The state maintained across applications of various ops implementing the
/// TransformOpInterface. The operations implementing this interface and the
@@ -193,7 +194,7 @@ class TransformState {
friend LogicalResult applyTransforms(Operation *, TransformOpInterface,
const RaggedArray<MappedValue> &,
- const TransformOptions &);
+ const TransformOptions &, bool);
friend TransformState
detail::makeTransformStateForTesting(Region *region, Operation *payloadRoot);
diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
new file mode 100644
index 000000000000000..36c80e6fd61d3c1
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
@@ -0,0 +1,89 @@
+//===- TransformInterpreterUtils.h - Transform Utils ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERUTILS_H
+#define MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERUTILS_H
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include <memory>
+
+namespace mlir {
+struct LogicalResult;
+class MLIRContext;
+class ModuleOp;
+class Operation;
+template <typename>
+class OwningOpRef;
+class Region;
+
+namespace transform {
+namespace detail {
+/// Parses and verifies the transform dialect specification stored in the MLIR
+/// file `transformFileName`. Succeeds without doing anything if it is empty.
+LogicalResult
+parseTransformModuleFromFile(MLIRContext *context,
+ llvm::StringRef transformFileName,
+ OwningOpRef<ModuleOp> &transformModule);
+
+/// Utility to retrieve a transform interpreter module that has already been
+/// preloaded into the context, i.e. registered with the transform dialect via
+/// `TransformDialect::registerLibraryModule`.
+/// This mode is useful in cases where explicit parsing of a transform library
+/// from file is expected to be prohibitively expensive.
+/// Only the first registered library module is returned; returns null if no
+/// library module has been registered.
+ModuleOp getPreloadedTransformModule(MLIRContext *context);
+
+/// Finds the first `transform.named_sequence` whose symbol name is `entryPoint`
+/// (`kTransformEntryPointSymbolName` by default) that is either:
+/// 1. nested under `root` (takes precedence).
+/// 2. nested under `module`, if not found in `root`.
+/// Reports an error on `root` and returns null if no such op is found.
+TransformOpInterface findTransformEntryPoint(
+ Operation *root, ModuleOp module,
+ StringRef entryPoint = TransformDialect::kTransformEntryPointSymbolName);
+
+/// Merge all symbols from `other` into `target`. Both ops need to implement the
+/// `SymbolTable` trait. Operations are moved from `other`, i.e., `other` may be
+/// modified by this function and might not verify after the function returns.
+/// Upon merging, private symbols may be renamed in order to avoid collisions in
+/// the result. Public symbols may not collide, with the exception of
+/// instances of `FunctionOpInterface`, where collisions are allowed if at least
+/// one of the two is external, in which case the other op is preserved (or any
+/// one of the two if both are external).
+// TODO: Reconsider cloning individual ops rather than forcing users of the
+// function to clone (or move) `other` in order to improve efficiency.
+// This might primarily make sense if we can also prune the symbols that
+// are merged to a subset (such as those that are actually used).
+LogicalResult mergeSymbolsInto(Operation *target,
+ OwningOpRef<Operation *> other);
+} // namespace detail
+
+/// Standalone util to apply the named sequence `entryPoint` to the payload.
+/// This is done in 3 steps:
+/// 1. lookup the `entryPoint` symbol in `{payload, transformModule}` by
+/// calling detail::findTransformEntryPoint.
+/// 2. if the entry point is found and not nested under `transformModule`,
+/// call `detail::mergeSymbolsInto` to "link" in the `transformModule`.
+/// Note: this may modify the transform IR embedded with the payload IR.
+/// 3. apply the transform IR to the payload IR, relaxing the requirement that
+/// the transform IR is a top-level transform op (see the
+/// `enforceToplevelTransformOp` flag of `applyTransforms`). We are applying a
+/// named sequence anyway.
+LogicalResult applyTransformNamedSequence(
+ Operation *payload, ModuleOp transformModule,
+ const TransformOptions &options,
+ StringRef entryPoint = TransformDialect::kTransformEntryPointSymbolName);
+
+} // namespace transform
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERUTILS_H
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 4a9bb2dba7d660c..4f88b8522e54c80 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -2079,20 +2079,20 @@ LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
// Entry point.
//===----------------------------------------------------------------------===//
-LogicalResult
-transform::applyTransforms(Operation *payloadRoot,
- TransformOpInterface transform,
- const RaggedArray<MappedValue> &extraMapping,
- const TransformOptions &options) {
-#ifndef NDEBUG
- if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
- transform->getNumOperands() != 0) {
- transform->emitError()
- << "expected transform to start at the top-level transform op";
- llvm::report_fatal_error("could not run transforms",
- /*gen_crash_diag=*/false);
+LogicalResult transform::applyTransforms(
+ Operation *payloadRoot, TransformOpInterface transform,
+ const RaggedArray<MappedValue> &extraMapping,
+ const TransformOptions &options, bool enforceToplevelTransformOp) {
+ if (enforceToplevelTransformOp) {
+ if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
+ transform->getNumOperands() != 0) {
+ return transform->emitError()
+ << "expected transform to start at the top-level transform op";
+ }
+ } else if (failed(
+ detail::verifyPossibleTopLevelTransformOpTrait(transform))) {
+ return failure();
}
-#endif // NDEBUG
TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
options);
diff --git a/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
index 3f51ef1088f7af6..8774a8b86fb0d91 100644
--- a/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRTransformDialectTransforms
CheckUses.cpp
InferEffects.cpp
TransformInterpreterPassBase.cpp
+ TransformInterpreterUtils.cpp
DEPENDS
MLIRTransformDialectTransformsIncGen
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index 764d7e25854206e..ebfd7269f696bbc 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
@@ -51,34 +52,6 @@ constexpr static llvm::StringLiteral kTransformDialectTagPayloadRootValue =
constexpr static llvm::StringLiteral
kTransformDialectTagTransformContainerValue = "transform_container";
-/// Utility to parse the content of a `transformFileName` MLIR file containing
-/// a transform dialect specification.
-static LogicalResult
-parseTransformModuleFromFile(MLIRContext *context,
- llvm::StringRef transformFileName,
- OwningOpRef<ModuleOp> &transformModule) {
- if (transformFileName.empty()) {
- LLVM_DEBUG(
- DBGS() << "no transform file name specified, assuming the transform "
- "module is embedded in the IR next to the top-level\n");
- return success();
- }
- // Parse transformFileName content into a ModuleOp.
- std::string errorMessage;
- auto memoryBuffer = mlir::openInputFile(transformFileName, &errorMessage);
- if (!memoryBuffer) {
- return emitError(FileLineColLoc::get(
- StringAttr::get(context, transformFileName), 0, 0))
- << "failed to open transform file: " << errorMessage;
- }
- // Tell sourceMgr about this buffer, the parser will pick it up.
- llvm::SourceMgr sourceMgr;
- sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
- transformModule =
- OwningOpRef<ModuleOp>(parseSourceFile<ModuleOp>(sourceMgr, context));
- return success();
-}
-
/// Finds the single top-level transform operation with `root` as ancestor.
/// Reports an error if there is more than one such operation and returns the
/// first one found. Reports an error returns nullptr if no such operation
@@ -295,239 +268,6 @@ static void performOptionalDebugActions(
transform->removeAttr(kTransformDialectTagAttrName);
}
-/// Return whether `func1` can be merged into `func2`. For that to work `func1`
-/// has to be a declaration (aka has to be external) and `func2` either has to
-/// be a declaration as well, or it has to be public (otherwise, it wouldn't
-/// be visible by `func1`).
-static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
- return func1.isExternal() && (func2.isPublic() || func2.isExternal());
-}
-
-/// Merge `func1` into `func2`. The two ops must be inside the same parent op
-/// and mergable according to `canMergeInto`. The function erases `func1` such
-/// that only `func2` exists when the function returns.
-static LogicalResult mergeInto(FunctionOpInterface func1,
- FunctionOpInterface func2) {
- assert(canMergeInto(func1, func2));
- assert(func1->getParentOp() == func2->getParentOp() &&
- "expected func1 and func2 to be in the same parent op");
-
- // Check that function signatures match.
- if (func1.getFunctionType() != func2.getFunctionType()) {
- return func1.emitError()
- << "external definition has a mismatching signature ("
- << func2.getFunctionType() << ")";
- }
-
- // Check and merge argument attributes.
- MLIRContext *context = func1->getContext();
- auto *td = context->getLoadedDialect<transform::TransformDialect>();
- StringAttr consumedName = td->getConsumedAttrName();
- StringAttr readOnlyName = td->getReadOnlyAttrName();
- for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) {
- bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr;
- bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr;
- bool isConsumed = func1.getArgAttr(i, consumedName) != nullptr;
- bool isReadonly = func1.getArgAttr(i, readOnlyName) != nullptr;
- if (!isExternalConsumed && !isExternalReadonly) {
- if (isConsumed)
- func2.setArgAttr(i, consumedName, UnitAttr::get(context));
- else if (isReadonly)
- func2.setArgAttr(i, readOnlyName, UnitAttr::get(context));
- continue;
- }
-
- if ((isExternalConsumed && !isConsumed) ||
- (isExternalReadonly && !isReadonly)) {
- return func1.emitError()
- << "external definition has mismatching consumption "
- "annotations for argument #"
- << i;
- }
- }
-
- // `func1` is the external one, so we can remove it.
- assert(func1.isExternal());
- func1->erase();
-
- return success();
-}
-
-/// Merge all symbols from `other` into `target`. Both ops need to implement the
-/// `SymbolTable` trait. Operations are moved from `other`, i.e., `other` may be
-/// modified by this function and might not verify after the function returns.
-/// Upon merging, private symbols may be renamed in order to avoid collisions in
-/// the result. Public symbols may not collide, with the exception of
-/// instances of `SymbolOpInterface`, where collisions are allowed if at least
-/// one of the two is external, in which case the other op preserved (or any one
-/// of the two if both are external).
-// TODO: Reconsider cloning individual ops rather than forcing users of the
-// function to clone (or move) `other` in order to improve efficiency.
-// This might primarily make sense if we can also prune the symbols that
-// are merged to a subset (such as those that are actually used).
-static LogicalResult mergeSymbolsInto(Operation *target,
- OwningOpRef<Operation *> other) {
- assert(target->hasTrait<OpTrait::SymbolTable>() &&
- "requires target to implement the 'SymbolTable' trait");
- assert(other->hasTrait<OpTrait::SymbolTable>() &&
- "requires target to implement the 'SymbolTable' trait");
-
- SymbolTable targetSymbolTable(target);
- SymbolTable otherSymbolTable(*other);
-
- // Step 1:
- //
- // Rename private symbols in both ops in order to resolve conflicts that can
- // be resolved that way.
- LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n");
- // TODO: Do we *actually* need to test in both directions?
- for (auto &&[symbolTable, otherSymbolTable] : llvm::zip(
- SmallVector<SymbolTable *, 2>{&targetSymbolTable, &otherSymbolTable},
- SmallVector<SymbolTable *, 2>{&otherSymbolTable,
- &targetSymbolTable})) {
- Operation *symbolTableOp = symbolTable->getOp();
- for (Operation &op : symbolTableOp->getRegion(0).front()) {
- auto symbolOp = dyn_cast<SymbolOpInterface>(op);
- if (!symbolOp)
- continue;
- StringAttr name = symbolOp.getNameAttr();
- LLVM_DEBUG(DBGS() << " found @" << name.getValue() << "\n");
-
- // Check if there is a colliding op in the other module.
- auto collidingOp =
- cast_or_null<SymbolOpInterface>(otherSymbolTable->lookup(name));
- if (!collidingOp)
- continue;
-
- LLVM_DEBUG(DBGS() << " collision found for @" << name.getValue());
-
- // Collisions are fine if both opt are functions and can be merged.
- if (auto funcOp = dyn_cast<FunctionOpInterface>(op),
- collidingFuncOp =
- dyn_cast<FunctionOpInterface>(collidingOp.getOperation());
- funcOp && collidingFuncOp) {
- if (canMergeInto(funcOp, collidingFuncOp) ||
- canMergeInto(collidingFuncOp, funcOp)) {
- LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and "
- "will be merged\n");
- continue;
- }
-
- // If they can't be merged, proceed like any other collision.
- LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions");
- }
-
- // Collision can be resolved by renaming if one of the ops is private.
- auto renameToUnique =
- [&](SymbolOpInterface op, SymbolOpInterface otherOp,
- SymbolTable &symbolTable,
- SymbolTable &otherSymbolTable) -> LogicalResult {
- LLVM_DEBUG(llvm::dbgs() << ", renaming\n");
- FailureOr<StringAttr> maybeNewName =
- symbolTable.renameToUnique(op, {&otherSymbolTable});
- if (failed(maybeNewName)) {
- InFlightDiagnostic diag = op->emitError("failed to rename symbol");
- diag.attachNote(otherOp->getLoc())
- << "attempted renaming due to collision with this op";
- return diag;
- }
- LLVM_DEBUG(DBGS() << " renamed to @" << maybeNewName->getValue()
- << "\n");
- return success();
- };
-
- if (symbolOp.isPrivate()) {
- if (failed(renameToUnique(symbolOp, collidingOp, *symbolTable,
- *otherSymbolTable)))
- return failure();
- continue;
- }
- if (collidingOp.isPrivate()) {
- if (failed(renameToUnique(collidingOp, symbolOp, *otherSymbolTable,
- *symbolTable)))
- return failure();
- continue;
- }
-
- LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");
- InFlightDiagnostic diag = symbolOp.emitError()
- << "doubly defined symbol @" << name.getValue();
- diag.attachNote(collidingOp->getLoc()) << "previously defined here";
- return diag;
- }
- }
-
- // TODO: This duplicates pass infrastructure. We should split this pass into
- // several and let the pass infrastructure do the verification.
- for (auto *op : SmallVector<Operation *>{target, *other}) {
- if (failed(mlir::verify(op)))
- return op->emitError() << "failed to verify input op after renaming";
- }
-
- // Step 2:
- //
- // Move all ops from `other` into target and merge public symbols.
- LLVM_DEBUG(DBGS() << "moving all symbols into target\n");
- {
- SmallVector<SymbolOpInterface> opsToMove;
- for (Operation &op : other->getRegion(0).front()) {
- if (auto symbol = dyn_cast<SymbolOpInterface>(op))
- opsToMove.push_back(symbol);
- }
-
- for (SymbolOpInterface op : opsToMove) {
- // Remember potentially colliding op in the target module.
- auto collidingOp = cast_or_null<SymbolOpInterface>(
- targetSymbolTable.lookup(op.getNameAttr()));
-
- // Move op even if we get a collision.
- LLVM_DEBUG(DBGS() << " moving @" << op.getName());
- op->moveBefore(&target->getRegion(0).front(),
- target->getRegion(0).front().end());
-
- // If there is no collision, we are done.
- if (!collidingOp) {
- LLVM_DEBUG(llvm::dbgs() << " without collision\n");
- continue;
- }
-
- // The two colliding ops must both be functions because we have already
- // emitted errors otherwise earlier.
- auto funcOp = cast<FunctionOpInterface>(op.getOperation());
- auto collidingFuncOp =
- cast<FunctionOpInterface>(collidingOp.getOperation());
-
- // Both ops are in the target module now and can be treated symmetrically,
- // so w.l.o.g. we can reduce to merging `funcOp` into `collidingFuncOp`.
- if (!canMergeInto(funcOp, collidingFuncOp)) {
- std::swap(funcOp, collidingFuncOp);
- }
- assert(canMergeInto(funcOp, collidingFuncOp));
-
- LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at "
- << collidingFuncOp.getLoc() << ":\n"
- << collidingFuncOp << "\n");
-
- // Update symbol table. This works with or without the previous `swap`.
- targetSymbolTable.remove(funcOp);
- targetSymbolTable.insert(collidingFuncOp);
- assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);
-
- // Do the actual merging.
- if (failed(mergeInto(funcOp, collidingFuncOp))) {
- return failure();
- }
- }
- }
-
- if (failed(mlir::verify(target)))
- return target->emitError()
- << "failed to verify target op after merging symbols";
-
- LLVM_DEBUG(DBGS() << "done merging ops\n");
- return success();
-}
-
LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
Operation *target, StringRef passName,
const std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
@@ -595,9 +335,9 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
diag.attachNote(target->getLoc()) << "pass anchor op";
return diag;
}
- if (failed(
- mergeSymbolsInto(SymbolTable::getNearestSymbolTable(transformRoot),
- transformLibraryModule->get()->clone())))
+ if (failed(detail::mergeSymbolsInto(
+ SymbolTable::getNearestSymbolTable(transformRoot),
+ transformLibraryModule->get()->clone())))
return emitError(transformRoot->getLoc(),
"failed to merge library symbols into transform root");
}
@@ -683,8 +423,8 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
OwningOpRef<ModuleOp> moduleFromFile;
{
auto loc = FileLineColLoc::get(context, transformFileName, 0, 0);
- if (failed(parseTransformModuleFromFile(context, transformFileName,
- moduleFromFile)))
+ if (failed(detail::parseTransformModuleFromFile(context, transformFileName,
+ moduleFromFile)))
return emitError(loc) << "failed to parse transform module";
if (moduleFromFile && failed(mlir::verify(*moduleFromFile)))
return emitError(loc) << "failed to verify transform module";
@@ -701,8 +441,8 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
for (const std::string &libraryFileName : libraryFileNames) {
OwningOpRef<ModuleOp> parsedLibrary;
auto loc = FileLineColLoc::get(context, libraryFileName, 0, 0);
- if (failed(parseTransformModuleFromFile(context, libraryFileName,
- parsedLibrary)))
+ if (failed(detail::parseTransformModuleFromFile(context, libraryFileName,
+ parsedLibrary)))
return emitError(loc) << "failed to parse transform library module";
if (parsedLibrary && failed(mlir::verify(*parsedLibrary)))
return emitError(loc) << "failed to verify transform library module";
@@ -741,8 +481,8 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
IRRewriter rewriter(context);
// TODO: extend `mergeSymbolsInto` to support multiple `other` modules.
for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) {
- if (failed(mergeSymbolsInto(mergedParsedLibraries.get(),
- std::move(parsedLibrary))))
+ if (failed(detail::mergeSymbolsInto(mergedParsedLibraries.get(),
+ std::move(parsedLibrary))))
return mergedParsedLibraries->emitError()
<< "failed to verify merged transform module";
}
@@ -751,8 +491,8 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
// Use parsed libaries to resolve symbols in shared transform module or return
// as separate library module.
if (sharedTransformModule && *sharedTransformModule) {
- if (failed(mergeSymbolsInto(sharedTransformModule->get(),
- std::move(mergedParsedLibraries))))
+ if (failed(detail::mergeSymbolsInto(sharedTransformModule->get(),
+ std::move(mergedParsedLibraries))))
return (*sharedTransformModule)->emitError()
<< "failed to merge symbols from library files "
"into shared transform module";
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
new file mode 100644
index 000000000000000..1a6ebdd16232e8a
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
@@ -0,0 +1,337 @@
+//===- TransformInterpreterUtils.cpp --------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Lightweight transform dialect interpreter utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Support/FileUtilities.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+#define DEBUG_TYPE "transform-dialect-interpreter-utils"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+LogicalResult transform::detail::parseTransformModuleFromFile(
+    MLIRContext *context, llvm::StringRef transformFileName,
+    OwningOpRef<ModuleOp> &transformModule) {
+  if (transformFileName.empty()) {
+    LLVM_DEBUG(
+        DBGS() << "no transform file name specified, assuming the transform "
+                  "module is embedded in the IR next to the top-level\n");
+    return success();
+  }
+  std::string errorMessage;
+  auto memoryBuffer = mlir::openInputFile(transformFileName, &errorMessage);
+  if (!memoryBuffer) {
+    return emitError(FileLineColLoc::get(
+               StringAttr::get(context, transformFileName), 0, 0))
+           << "failed to open transform file: " << errorMessage;
+  }
+  // Tell sourceMgr about this buffer, the parser will pick it up.
+  llvm::SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
+  transformModule =
+      OwningOpRef<ModuleOp>(parseSourceFile<ModuleOp>(sourceMgr, context));
+  // A null module means parsing failed; the parser already emitted errors.
+  return success(transformModule && succeeded(mlir::verify(*transformModule)));
+}
+
+ModuleOp transform::detail::getPreloadedTransformModule(MLIRContext *context) {
+  // Avoid loading the dialect as a side effect (not allowed in multi-threaded
+  // passes): modules can only be preloaded through a loaded dialect instance.
+  auto *dialect = context->getLoadedDialect<transform::TransformDialect>();
+  if (!dialect || dialect->getLibraryModules().empty())
+    return ModuleOp();
+  return *dialect->getLibraryModules().begin();
+}
+
+transform::TransformOpInterface
+transform::detail::findTransformEntryPoint(Operation *root, ModuleOp module,
+                                           StringRef entryPoint) {
+  for (Operation *searchRoot : {root, module.getOperation()}) {
+    if (!searchRoot) // `module` is optional; `root` is searched first.
+      continue;
+    transform::TransformOpInterface transform = nullptr;
+    searchRoot->walk<WalkOrder::PreOrder>(
+        [&](transform::NamedSequenceOp namedSequenceOp) {
+          if (namedSequenceOp.getSymName() != entryPoint)
+            return WalkResult::advance();
+          transform = cast<transform::TransformOpInterface>(
+              namedSequenceOp.getOperation());
+          return WalkResult::interrupt();
+        });
+    if (transform)
+      return transform;
+  }
+  InFlightDiagnostic diag =
+      root->emitError() << "could not find a nested named sequence with name: "
+                        << entryPoint;
+  if (module)
+    diag.attachNote(module.getLoc()) << "also searched in this module";
+  return nullptr;
+}
+
+/// Returns true if `func1` is a declaration (external) that can be folded
+/// into `func2`, which must be a declaration too or public to be visible.
+static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
+  if (!func1.isExternal())
+    return false;
+  return func2.isPublic() || func2.isExternal();
+}
+
+/// Merges the external `func1` into `func2`; both must share a parent op and be
+/// mergeable per `canMergeInto`. On success `func1` is erased so only `func2`
+/// remains. On failure `func1` is kept but `func2` may be partially updated.
+static LogicalResult mergeInto(FunctionOpInterface func1,
+ FunctionOpInterface func2) {
+ assert(canMergeInto(func1, func2));
+ assert(func1->getParentOp() == func2->getParentOp() &&
+ "expected func1 and func2 to be in the same parent op");
+
+ // Check that function signatures match.
+ if (func1.getFunctionType() != func2.getFunctionType()) {
+ return func1.emitError()
+ << "external definition has a mismatching signature ("
+ << func2.getFunctionType() << ")";
+ }
+ // Check and merge the consumed/readonly argument attributes.
+ // NOTE(review): `td` is not null-checked; assumes the dialect is loaded.
+ MLIRContext *context = func1->getContext();
+ auto *td = context->getLoadedDialect<transform::TransformDialect>();
+ StringAttr consumedName = td->getConsumedAttrName();
+ StringAttr readOnlyName = td->getReadOnlyAttrName();
+ for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) {
+ bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr;
+ bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr;
+ bool isConsumed = func1.getArgAttr(i, consumedName) != nullptr;
+ bool isReadonly = func1.getArgAttr(i, readOnlyName) != nullptr;
+ if (!isExternalConsumed && !isExternalReadonly) {
+ if (isConsumed)
+ func2.setArgAttr(i, consumedName, UnitAttr::get(context));
+ else if (isReadonly)
+ func2.setArgAttr(i, readOnlyName, UnitAttr::get(context));
+ continue;
+ }
+ // Otherwise, each annotation present on `func2` must also be on `func1`.
+ if ((isExternalConsumed && !isConsumed) ||
+ (isExternalReadonly && !isReadonly)) {
+ return func1.emitError()
+ << "external definition has mismatching consumption "
+ "annotations for argument #"
+ << i;
+ }
+ }
+
+ // `func1` is the external one, so we can remove it.
+ assert(func1.isExternal());
+ func1->erase();
+
+ return success();
+}
+
+/// Moves all symbol ops of `other` into `target`, resolving name collisions.
+///
+/// Both ops are expected to carry the `SymbolTable` trait. In a first step,
+/// every collision between the two symbol tables is examined:
+///  - pairs of function ops for which `canMergeInto` holds in either
+///    direction are left alone here and merged with `mergeInto` later;
+///  - otherwise, if either colliding symbol is private, it is renamed to a
+///    name that is unique in both tables;
+///  - any remaining collision is reported as a "doubly defined symbol" error.
+/// Both ops are then verified. In a second step, all symbol ops of `other` are
+/// moved into `target`, mergeable function pairs are merged, and `target` is
+/// verified again.
+LogicalResult
+transform::detail::mergeSymbolsInto(Operation *target,
+                                    OwningOpRef<Operation *> other) {
+  assert(target->hasTrait<OpTrait::SymbolTable>() &&
+         "requires target to implement the 'SymbolTable' trait");
+  // Note: this assertion is about `other`, not `target`.
+  assert(other->hasTrait<OpTrait::SymbolTable>() &&
+         "requires other to implement the 'SymbolTable' trait");
+
+  SymbolTable targetSymbolTable(target);
+  SymbolTable otherSymbolTable(*other);
+
+  // Step 1:
+  //
+  // Rename private symbols in both ops in order to resolve conflicts that can
+  // be resolved that way.
+  LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n");
+  // TODO: Do we *actually* need to test in both directions?
+  for (auto &&[symbolTable, otherSymbolTable] : llvm::zip(
+           SmallVector<SymbolTable *, 2>{&targetSymbolTable, &otherSymbolTable},
+           SmallVector<SymbolTable *, 2>{&otherSymbolTable,
+                                         &targetSymbolTable})) {
+    Operation *symbolTableOp = symbolTable->getOp();
+    for (Operation &op : symbolTableOp->getRegion(0).front()) {
+      auto symbolOp = dyn_cast<SymbolOpInterface>(op);
+      if (!symbolOp)
+        continue;
+      StringAttr name = symbolOp.getNameAttr();
+      LLVM_DEBUG(DBGS() << " found @" << name.getValue() << "\n");
+
+      // Check if there is a colliding op in the other module.
+      auto collidingOp =
+          cast_or_null<SymbolOpInterface>(otherSymbolTable->lookup(name));
+      if (!collidingOp)
+        continue;
+
+      LLVM_DEBUG(DBGS() << " collision found for @" << name.getValue());
+
+      // Collisions are fine if both ops are functions and can be merged.
+      if (auto funcOp = dyn_cast<FunctionOpInterface>(op),
+          collidingFuncOp =
+              dyn_cast<FunctionOpInterface>(collidingOp.getOperation());
+          funcOp && collidingFuncOp) {
+        if (canMergeInto(funcOp, collidingFuncOp) ||
+            canMergeInto(collidingFuncOp, funcOp)) {
+          LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and "
+                                     "will be merged\n");
+          continue;
+        }
+
+        // If they can't be merged, proceed like any other collision.
+        LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions");
+      }
+
+      // Collision can be resolved by renaming if one of the ops is private.
+      auto renameToUnique =
+          [&](SymbolOpInterface op, SymbolOpInterface otherOp,
+              SymbolTable &symbolTable,
+              SymbolTable &otherSymbolTable) -> LogicalResult {
+        LLVM_DEBUG(llvm::dbgs() << ", renaming\n");
+        FailureOr<StringAttr> maybeNewName =
+            symbolTable.renameToUnique(op, {&otherSymbolTable});
+        if (failed(maybeNewName)) {
+          InFlightDiagnostic diag = op->emitError("failed to rename symbol");
+          diag.attachNote(otherOp->getLoc())
+              << "attempted renaming due to collision with this op";
+          return diag;
+        }
+        LLVM_DEBUG(DBGS() << " renamed to @" << maybeNewName->getValue()
+                          << "\n");
+        return success();
+      };
+
+      if (symbolOp.isPrivate()) {
+        if (failed(renameToUnique(symbolOp, collidingOp, *symbolTable,
+                                  *otherSymbolTable)))
+          return failure();
+        continue;
+      }
+      if (collidingOp.isPrivate()) {
+        if (failed(renameToUnique(collidingOp, symbolOp, *otherSymbolTable,
+                                  *symbolTable)))
+          return failure();
+        continue;
+      }
+      LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");
+      InFlightDiagnostic diag = symbolOp.emitError()
+                                << "doubly defined symbol @" << name.getValue();
+      diag.attachNote(collidingOp->getLoc()) << "previously defined here";
+      return diag;
+    }
+  }
+
+  // TODO: This duplicates pass infrastructure. We should split this pass into
+  // several and let the pass infrastructure do the verification.
+  for (auto *op : SmallVector<Operation *>{target, *other}) {
+    if (failed(mlir::verify(op)))
+      return op->emitError() << "failed to verify input op after renaming";
+  }
+
+  // Step 2:
+  //
+  // Move all ops from `other` into target and merge public symbols.
+  LLVM_DEBUG(DBGS() << "moving all symbols into target\n");
+  {
+    // Collect first: moving ops while iterating `other`'s block would
+    // invalidate the iteration.
+    SmallVector<SymbolOpInterface> opsToMove;
+    for (Operation &op : other->getRegion(0).front()) {
+      if (auto symbol = dyn_cast<SymbolOpInterface>(op))
+        opsToMove.push_back(symbol);
+    }
+
+    for (SymbolOpInterface op : opsToMove) {
+      // Remember potentially colliding op in the target module.
+      auto collidingOp = cast_or_null<SymbolOpInterface>(
+          targetSymbolTable.lookup(op.getNameAttr()));
+
+      // Move op even if we get a collision.
+      LLVM_DEBUG(DBGS() << " moving @" << op.getName());
+      op->moveBefore(&target->getRegion(0).front(),
+                     target->getRegion(0).front().end());
+
+      // If there is no collision, we are done.
+      if (!collidingOp) {
+        LLVM_DEBUG(llvm::dbgs() << " without collision\n");
+        continue;
+      }
+
+      // The two colliding ops must both be functions because we have already
+      // emitted errors otherwise earlier.
+      auto funcOp = cast<FunctionOpInterface>(op.getOperation());
+      auto collidingFuncOp =
+          cast<FunctionOpInterface>(collidingOp.getOperation());
+
+      // Both ops are in the target module now and can be treated
+      // symmetrically, so w.l.o.g. we can reduce to merging `funcOp` into
+      // `collidingFuncOp`.
+      if (!canMergeInto(funcOp, collidingFuncOp)) {
+        std::swap(funcOp, collidingFuncOp);
+      }
+      assert(canMergeInto(funcOp, collidingFuncOp));
+
+      LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at "
+                              << collidingFuncOp.getLoc() << ":\n"
+                              << collidingFuncOp << "\n");
+
+      // Update symbol table. This works with or without the previous `swap`.
+      targetSymbolTable.remove(funcOp);
+      targetSymbolTable.insert(collidingFuncOp);
+      assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);
+
+      // Do the actual merging.
+      if (failed(mergeInto(funcOp, collidingFuncOp))) {
+        return failure();
+      }
+    }
+  }
+
+  if (failed(mlir::verify(target)))
+    return target->emitError()
+           << "failed to verify target op after merging symbols";
+
+  LLVM_DEBUG(DBGS() << "done merging ops\n");
+  return success();
+}
+
+/// Applies the transform entry point named `entryPoint` to `payload`.
+///
+/// The entry point is located with `detail::findTransformEntryPoint`. When it
+/// is not nested inside `transformModule`, the symbols of a *clone* of
+/// `transformModule` are merged into the symbol table enclosing the entry
+/// point, so `transformModule` itself is never modified. The transforms are
+/// then applied without enforcing top-level transform op constraints.
+LogicalResult transform::applyTransformNamedSequence(
+    Operation *payload, ModuleOp transformModule,
+    const TransformOptions &options, StringRef entryPoint) {
+  Operation *transformRoot =
+      detail::findTransformEntryPoint(payload, transformModule, entryPoint);
+  if (!transformRoot)
+    return failure();
+
+  // `transformModule` may not be modified, so merge a clone of it. The guard
+  // treats a null module as "nothing to merge"; the clone must therefore be
+  // created only after that check (cloning first would dereference a null
+  // module) and only when a merge is actually needed.
+  if (transformModule && !transformModule->isAncestor(transformRoot)) {
+    OwningOpRef<Operation *> clonedTransformModule(transformModule->clone());
+    if (failed(detail::mergeSymbolsInto(
+            SymbolTable::getNearestSymbolTable(transformRoot),
+            std::move(clonedTransformModule))))
+      return failure();
+  }
+
+  // Apply the transform to the IR, do not enforce top-level constraints.
+  RaggedArray<MappedValue> noExtraMappings;
+  return applyTransforms(payload, cast<TransformOpInterface>(transformRoot),
+                         noExtraMappings, options,
+                         /*enforceToplevelTransformOp=*/false);
+}
diff --git a/mlir/unittests/Dialect/Transform/CMakeLists.txt b/mlir/unittests/Dialect/Transform/CMakeLists.txt
index 1fecd21221c91c8..89238a0bdae16ea 100644
--- a/mlir/unittests/Dialect/Transform/CMakeLists.txt
+++ b/mlir/unittests/Dialect/Transform/CMakeLists.txt
@@ -1,8 +1,11 @@
add_mlir_unittest(MLIRTransformDialectTests
BuildOnlyExtensionTest.cpp
+ Preload.cpp
)
target_link_libraries(MLIRTransformDialectTests
PRIVATE
MLIRFuncDialect
+ MLIRTestTransformDialect
MLIRTransformDialect
+ MLIRTransformDialectTransforms
)
diff --git a/mlir/unittests/Dialect/Transform/Preload.cpp b/mlir/unittests/Dialect/Transform/Preload.cpp
new file mode 100644
index 000000000000000..d3c3044e0e0f776
--- /dev/null
+++ b/mlir/unittests/Dialect/Transform/Preload.cpp
@@ -0,0 +1,92 @@
+//===- Preload.cpp - Test transform library preloading --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/TypeID.h"
+#include "mlir/Tools/mlir-opt/MlirOptMain.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/raw_ostream.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+// Forward declarations of entry points from the MLIR test libraries.
+// NOTE(review): presumably provided by MLIRTestTransformDialect, which the
+// unittest links against; confirm if this moves.
+namespace mlir {
+namespace test {
+/// Creates the test transform dialect interpreter pass.
+std::unique_ptr<Pass> createTestTransformDialectInterpreterPass();
+} // namespace test
+} // namespace mlir
+namespace test {
+/// Registers the test extension of the transform dialect into `registry`.
+void registerTestTransformDialectExtension(DialectRegistry &registry);
+} // namespace test
+
+/// Transform library providing the *definition* of the `__transform_main`
+/// named sequence, which prints a remark at its operand via
+/// `transform.test_print_remark_at_operand`.
+static constexpr llvm::StringLiteral library = R"MLIR(
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence private @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ transform.test_print_remark_at_operand %arg0, "from external symbol" : !transform.any_op
+ transform.yield
+ }
+})MLIR";
+
+/// Input module that only *declares* `__transform_main` and calls it through
+/// `include` from a top-level `transform.sequence`; the definition has to be
+/// supplied by `library`.
+static constexpr llvm::StringLiteral input = R"MLIR(
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence private @__transform_main(%arg0: !transform.any_op {transform.readonly})
+
+ transform.sequence failures(propagate) {
+ ^bb0(%arg0: !transform.any_op):
+ include @__transform_main failures(propagate) (%arg0) : (!transform.any_op) -> ()
+ }
+})MLIR";
+
+/// Registers `library` as a preloaded transform library on the transform
+/// dialect, retrieves it from the context, merges its symbols into the parsed
+/// `input` module (which only declares `__transform_main`), and applies the
+/// named sequence entry point.
+TEST(Preload, ContextPreloadConstructedLibrary) {
+  registerPassManagerCLOptions();
+
+  MLIRContext context;
+  auto *dialect = context.getOrLoadDialect<transform::TransformDialect>();
+  DialectRegistry registry;
+  ::test::registerTestTransformDialectExtension(registry);
+  registry.applyExtensions(&context);
+  ParserConfig parserConfig(&context);
+
+  // Values checked with ASSERT_* below are used by later steps; stop the
+  // test on failure instead of crashing on a null value.
+  OwningOpRef<ModuleOp> inputModule =
+      parseSourceString<ModuleOp>(input, parserConfig, "<input>");
+  ASSERT_TRUE(inputModule) << "failed to parse input module";
+
+  OwningOpRef<ModuleOp> transformLibrary =
+      parseSourceString<ModuleOp>(library, parserConfig, "<transform-library>");
+  ASSERT_TRUE(transformLibrary) << "failed to parse transform module";
+  dialect->registerLibraryModule(std::move(transformLibrary));
+
+  ModuleOp retrievedTransformLibrary =
+      transform::detail::getPreloadedTransformModule(&context);
+  ASSERT_TRUE(retrievedTransformLibrary)
+      << "failed to retrieve transform module";
+
+  transform::TransformOpInterface entryPoint =
+      transform::detail::findTransformEntryPoint(inputModule->getOperation(),
+                                                 retrievedTransformLibrary);
+  EXPECT_TRUE(entryPoint) << "failed to find entry point";
+
+  OwningOpRef<Operation *> clonedTransformModule(
+      retrievedTransformLibrary->clone());
+  LogicalResult res = transform::detail::mergeSymbolsInto(
+      inputModule->getOperation(), std::move(clonedTransformModule));
+  EXPECT_TRUE(succeeded(res)) << "failed to define declared symbols";
+
+  transform::TransformOptions options;
+  res = transform::applyTransformNamedSequence(
+      inputModule->getOperation(), retrievedTransformLibrary, options);
+  EXPECT_TRUE(succeeded(res)) << "failed to apply named sequence";
+}
More information about the Mlir-commits
mailing list