[Mlir-commits] [mlir] [mlir][Transform] Provide a minimal set of utils that allow implementing a simple transform dialect interpreter pass (PR #68330)
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Oct 6 03:47:03 PDT 2023
https://github.com/nicolasvasilache updated https://github.com/llvm/llvm-project/pull/68330
>From 46621a376a75ec004bf3ef3932ec8c9e7ada2340 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolas.vasilache at gmail.com>
Date: Thu, 5 Oct 2023 16:11:00 +0000
Subject: [PATCH] [mlir][Transform] Provide a minimal set of utils that allow
implementing a simple transform dialect interpreter pass
---
.../Dialect/Transform/IR/TransformDialect.td | 29 ++-
.../Transform/IR/TransformInterfaces.h | 5 +-
.../Transforms/TransformInterpreterUtils.h | 78 ++++++++
.../Transform/IR/TransformInterfaces.cpp | 26 +--
.../Transform/Transforms/CMakeLists.txt | 1 +
.../TransformInterpreterPassBase.cpp | 115 +----------
.../Transforms/TransformInterpreterUtils.cpp | 188 ++++++++++++++++++
.../Dialect/Transform/CMakeLists.txt | 3 +
mlir/unittests/Dialect/Transform/Preload.cpp | 90 +++++++++
9 files changed, 410 insertions(+), 125 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
create mode 100644 mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
create mode 100644 mlir/unittests/Dialect/Transform/Preload.cpp
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
index 70a76ab9670f907..1097b15da77bb7d 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
@@ -20,10 +20,14 @@ def Transform_Dialect : Dialect {
let hasOperationAttrVerify = 1;
let extraClassDeclaration = [{
+ /// Default symbol name of the named sequence used as the interpreter entry.
+ constexpr const static ::llvm::StringLiteral
+ kTransformEntryPointSymbolName = "__transform_main";
+
/// Name of the attribute attachable to the symbol table operation
/// containing named sequences. This is used to trigger verification.
- constexpr const static ::llvm::StringLiteral kWithNamedSequenceAttrName =
- "transform.with_named_sequence";
+ constexpr const static ::llvm::StringLiteral
+ kWithNamedSequenceAttrName = "transform.with_named_sequence";
/// Name of the attribute attachable to an operation so it can be
/// identified as root by the default interpreter pass.
@@ -65,6 +69,22 @@ def Transform_Dialect : Dialect {
using ExtensionTypePrintingHook =
std::function<void(::mlir::Type, ::mlir::AsmPrinter &)>;
+ /// Takes ownership of `library` and appends it to the transform symbol
+ /// libraries available to all users of this dialect instance.
+ void registerLibraryModule(::mlir::OwningOpRef<::mlir::ModuleOp> &&
+ library) {
+ libraryModules.push_back(std::move(library));
+ }
+
+ /// Returns non-owning handles to the registered libraries, in registration order.
+ auto getLibraryModules() const {
+ return ::llvm::map_range(
+ libraryModules,
+ [](const ::mlir::OwningOpRef<::mlir::ModuleOp> &library) {
+ return library.get();
+ });
+ }
+
private:
/// Registers operations specified as template parameters with this
/// dialect. Checks that they implement the required interfaces.
@@ -123,6 +143,11 @@ def Transform_Dialect : Dialect {
/// lookups when the type is fully constructed.
::llvm::DenseMap<::mlir::TypeID, ExtensionTypePrintingHook>
typePrintingHooks;
+
+ /// Owned modules containing symbols, e.g. named sequences, that will be
+ /// resolved by the interpreter when used (see registerLibraryModule).
+ ::llvm::SmallVector<::mlir::OwningOpRef<::mlir::ModuleOp>, 2>
+ libraryModules;
}];
}
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 0e72a93e685e32f..7b37245fc3d117b 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -111,7 +111,8 @@ class TransformOptions {
LogicalResult
applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
const RaggedArray<MappedValue> &extraMapping = {},
- const TransformOptions &options = TransformOptions());
+ const TransformOptions &options = TransformOptions(),
+ bool enforceToplevelTransformOp = true);
/// The state maintained across applications of various ops implementing the
/// TransformOpInterface. The operations implementing this interface and the
@@ -193,7 +194,7 @@ class TransformState {
friend LogicalResult applyTransforms(Operation *, TransformOpInterface,
const RaggedArray<MappedValue> &,
- const TransformOptions &);
+ const TransformOptions &, bool);
friend TransformState
detail::makeTransformStateForTesting(Region *region, Operation *payloadRoot);
diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
new file mode 100644
index 000000000000000..6640c15d6b05729
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
@@ -0,0 +1,78 @@
+//===- TransformInterpreterUtils.h - Transform Utils ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERUTILS_H
+#define MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERUTILS_H
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include <memory>
+
+namespace mlir {
+struct LogicalResult;
+class MLIRContext;
+class ModuleOp;
+class Operation;
+template <typename>
+class OwningOpRef;
+class Region;
+
+namespace transform {
+namespace detail {
+/// Parses and verifies the transform dialect MLIR file `transformFileName`
+/// into `transformModule`. Succeeds without parsing if the name is empty.
+LogicalResult
+parseTransformInterpreterModule(MLIRContext *context,
+ llvm::StringRef transformFileName,
+ OwningOpRef<ModuleOp> &transformModule);
+
+/// Utility to load a transform interpreter `module` from a module that has
+/// already been preloaded in the context.
+/// This mode is useful in cases where explicit parsing of a transform library
+/// from file is expected to be prohibitively expensive.
+/// In such cases, the transform module is expected to be found in the preloaded
+/// library modules of the transform dialect.
+/// Returns the first registered library module, or null if there is none.
+ModuleOp getPreloadedTransformInterpreterModule(MLIRContext *context);
+
+/// Finds the first `transform.named_sequence` op whose name is `entryPoint`
+/// (possibly a bodiless declaration) that is either:
+/// 1. nested under `root` (takes precedence).
+/// 2. nested under `module`, if not found in `root`.
+/// Reports errors and returns null if no such operation found.
+TransformOpInterface findTransformEntryPoint(
+ Operation *root, ModuleOp module,
+ StringRef entryPoint = TransformDialect::kTransformEntryPointSymbolName);
+
+/// Replaces external (bodiless) symbols in `block` with clones of their
+/// definitions from `definitions`; fails on signature/annotation mismatch.
+LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions);
+} // namespace detail
+
+/// Standalone util to apply the named sequence `entryPoint` to the payload.
+/// This is done in 3 steps:
+/// 1. lookup the `entryPoint` symbol in `{payload, transformModule}` by
+///    calling detail::findTransformEntryPoint.
+/// 2. if the entry point is found and not nested under `transformModule`,
+///    call `detail::defineDeclaredSymbols` to "link" in the definitions from
+///    `transformModule`. Note: this may modify the transform IR embedded with
+///    the payload IR.
+/// 3. apply the transform IR to the payload IR, relaxing the requirement that
+///    the transform IR is a top-level transform op. We are applying a named
+///    sequence anyway.
+LogicalResult applyTransformNamedSequence(
+ Operation *payload, ModuleOp transformModule,
+ const TransformOptions &options,
+ StringRef entryPoint = TransformDialect::kTransformEntryPointSymbolName);
+
+} // namespace transform
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERUTILS_H
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 4a9bb2dba7d660c..4f88b8522e54c80 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -2079,20 +2079,20 @@ LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
// Entry point.
//===----------------------------------------------------------------------===//
-LogicalResult
-transform::applyTransforms(Operation *payloadRoot,
- TransformOpInterface transform,
- const RaggedArray<MappedValue> &extraMapping,
- const TransformOptions &options) {
-#ifndef NDEBUG
- if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
- transform->getNumOperands() != 0) {
- transform->emitError()
- << "expected transform to start at the top-level transform op";
- llvm::report_fatal_error("could not run transforms",
- /*gen_crash_diag=*/false);
+LogicalResult transform::applyTransforms(
+ Operation *payloadRoot, TransformOpInterface transform,
+ const RaggedArray<MappedValue> &extraMapping,
+ const TransformOptions &options, bool enforceToplevelTransformOp) {
+ if (enforceToplevelTransformOp) {
+ if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
+ transform->getNumOperands() != 0) {
+ return transform->emitError()
+ << "expected transform to start at the top-level transform op";
+ }
+ } else if (failed(
+ detail::verifyPossibleTopLevelTransformOpTrait(transform))) {
+ return failure();
}
-#endif // NDEBUG
TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
options);
diff --git a/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
index 3f51ef1088f7af6..8774a8b86fb0d91 100644
--- a/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRTransformDialectTransforms
CheckUses.cpp
InferEffects.cpp
TransformInterpreterPassBase.cpp
+ TransformInterpreterUtils.cpp
DEPENDS
MLIRTransformDialectTransformsIncGen
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index 23640c92457a89d..193f3c5447c29ea 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
@@ -50,34 +51,6 @@ constexpr static llvm::StringLiteral kTransformDialectTagPayloadRootValue =
constexpr static llvm::StringLiteral
kTransformDialectTagTransformContainerValue = "transform_container";
-/// Utility to parse the content of a `transformFileName` MLIR file containing
-/// a transform dialect specification.
-static LogicalResult
-parseTransformModuleFromFile(MLIRContext *context,
- llvm::StringRef transformFileName,
- OwningOpRef<ModuleOp> &transformModule) {
- if (transformFileName.empty()) {
- LLVM_DEBUG(
- DBGS() << "no transform file name specified, assuming the transform "
- "module is embedded in the IR next to the top-level\n");
- return success();
- }
- // Parse transformFileName content into a ModuleOp.
- std::string errorMessage;
- auto memoryBuffer = mlir::openInputFile(transformFileName, &errorMessage);
- if (!memoryBuffer) {
- return emitError(FileLineColLoc::get(
- StringAttr::get(context, transformFileName), 0, 0))
- << "failed to open transform file: " << errorMessage;
- }
- // Tell sourceMgr about this buffer, the parser will pick it up.
- llvm::SourceMgr sourceMgr;
- sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
- transformModule =
- OwningOpRef<ModuleOp>(parseSourceFile<ModuleOp>(sourceMgr, context));
- return success();
-}
-
/// Finds the single top-level transform operation with `root` as ancestor.
/// Reports an error if there is more than one such operation and returns the
/// first one found. Reports an error returns nullptr if no such operation
@@ -302,80 +275,6 @@ static void performOptionalDebugActions(
transform->removeAttr(kTransformDialectTagAttrName);
}
-/// Replaces external symbols in `block` with their (non-external) definitions
-/// from the given module.
-static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
- MLIRContext &ctx = *definitions->getContext();
- auto consumedName =
- StringAttr::get(&ctx, transform::TransformDialect::kArgConsumedAttrName);
- auto readOnlyName =
- StringAttr::get(&ctx, transform::TransformDialect::kArgReadOnlyAttrName);
-
- for (Operation &op : llvm::make_early_inc_range(block)) {
- LLVM_DEBUG(DBGS() << op << "\n");
- auto symbol = dyn_cast<SymbolOpInterface>(op);
- if (!symbol)
- continue;
- if (symbol->getNumRegions() == 1 && !symbol->getRegion(0).empty())
- continue;
-
- LLVM_DEBUG(DBGS() << "looking for definition of symbol "
- << symbol.getNameAttr() << ":");
- SymbolTable symbolTable(definitions);
- Operation *externalSymbol = symbolTable.lookup(symbol.getNameAttr());
- if (!externalSymbol || externalSymbol->getNumRegions() != 1 ||
- externalSymbol->getRegion(0).empty()) {
- LLVM_DEBUG(llvm::dbgs() << "not found\n");
- continue;
- }
-
- auto symbolFunc = dyn_cast<FunctionOpInterface>(op);
- auto externalSymbolFunc = dyn_cast<FunctionOpInterface>(externalSymbol);
- if (!symbolFunc || !externalSymbolFunc) {
- LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n");
- continue;
- }
-
- LLVM_DEBUG(llvm::dbgs() << "found @" << externalSymbol << "\n");
- if (symbolFunc.getFunctionType() != externalSymbolFunc.getFunctionType()) {
- return symbolFunc.emitError()
- << "external definition has a mismatching signature ("
- << externalSymbolFunc.getFunctionType() << ")";
- }
-
- for (unsigned i = 0, e = symbolFunc.getNumArguments(); i < e; ++i) {
- bool isExternalConsumed =
- externalSymbolFunc.getArgAttr(i, consumedName) != nullptr;
- bool isExternalReadonly =
- externalSymbolFunc.getArgAttr(i, readOnlyName) != nullptr;
- bool isConsumed = symbolFunc.getArgAttr(i, consumedName) != nullptr;
- bool isReadonly = symbolFunc.getArgAttr(i, readOnlyName) != nullptr;
- if (!isExternalConsumed && !isExternalReadonly) {
- if (isConsumed)
- externalSymbolFunc.setArgAttr(i, consumedName, UnitAttr::get(&ctx));
- else if (isReadonly)
- externalSymbolFunc.setArgAttr(i, readOnlyName, UnitAttr::get(&ctx));
- continue;
- }
-
- if ((isExternalConsumed && !isConsumed) ||
- (isExternalReadonly && !isReadonly)) {
- return symbolFunc.emitError()
- << "external definition has mismatching consumption annotations "
- "for argument #"
- << i;
- }
- }
-
- OpBuilder builder(&op);
- builder.setInsertionPoint(&op);
- builder.clone(*externalSymbol);
- symbol->erase();
- }
-
- return success();
-}
-
LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
Operation *target, StringRef passName,
const std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
@@ -443,8 +342,8 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
diag.attachNote(target->getLoc()) << "pass anchor op";
return diag;
}
- if (failed(defineDeclaredSymbols(*transformRoot->getBlock(),
- transformLibraryModule->get())))
+ if (failed(detail::defineDeclaredSymbols(*transformRoot->getBlock(),
+ transformLibraryModule->get())))
return failure();
}
@@ -471,15 +370,15 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
moduleBuilder) {
OwningOpRef<ModuleOp> parsedTransformModule;
- if (failed(parseTransformModuleFromFile(context, transformFileName,
- parsedTransformModule)))
+  if (failed(detail::parseTransformInterpreterModule(
+          context, transformFileName, parsedTransformModule)))
return failure();
if (parsedTransformModule && failed(mlir::verify(*parsedTransformModule)))
return failure();
OwningOpRef<ModuleOp> parsedLibraryModule;
- if (failed(parseTransformModuleFromFile(context, transformLibraryFileName,
- parsedLibraryModule)))
+  if (failed(detail::parseTransformInterpreterModule(
+          context, transformLibraryFileName, parsedLibraryModule)))
return failure();
if (parsedLibraryModule && failed(mlir::verify(*parsedLibraryModule)))
return failure();
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
new file mode 100644
index 000000000000000..a68b79813aa0f36
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
@@ -0,0 +1,188 @@
+//===- TransformInterpreterUtils.cpp --------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Lightweight transform dialect interpreter utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Support/FileUtilities.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+#define DEBUG_TYPE "transform-dialect-interpreter-utils"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+LogicalResult transform::detail::parseTransformInterpreterModule(
+    MLIRContext *context, llvm::StringRef transformFileName,
+    OwningOpRef<ModuleOp> &transformModule) {
+  if (transformFileName.empty()) {
+    LLVM_DEBUG(
+        DBGS() << "no transform file name specified, assuming the transform "
+                  "module is embedded in the IR next to the top-level\n");
+    return success();
+  }
+  // Parse transformFileName content into a ModuleOp.
+  std::string errorMessage;
+  auto memoryBuffer = mlir::openInputFile(transformFileName, &errorMessage);
+  if (!memoryBuffer) {
+    return emitError(FileLineColLoc::get(
+               StringAttr::get(context, transformFileName), 0, 0))
+           << "failed to open transform file: " << errorMessage;
+  }
+  // Tell sourceMgr about this buffer, the parser will pick it up.
+  llvm::SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
+  transformModule = parseSourceFile<ModuleOp>(sourceMgr, context);
+  // A failed parse yields a null module, which must not be dereferenced.
+  return transformModule ? mlir::verify(*transformModule) : failure();
+}
+
+ModuleOp transform::detail::getPreloadedTransformInterpreterModule(
+    MLIRContext *context) {
+  // The first library registered with the transform dialect is the one used;
+  // when no library has been registered, a null ModuleOp is returned.
+  auto *transformDialect =
+      context->getOrLoadDialect<transform::TransformDialect>();
+  auto libraries = transformDialect->getLibraryModules();
+  return libraries.empty() ? ModuleOp() : *libraries.begin();
+}
+
+transform::TransformOpInterface
+transform::detail::findTransformEntryPoint(Operation *root, ModuleOp module,
+                                           StringRef entryPoint) {
+  SmallVector<Operation *, 2> scopes{root};
+  if (module)
+    scopes.push_back(module);
+  for (Operation *scope : scopes) {
+    transform::TransformOpInterface transform = nullptr;
+    scope->walk<WalkOrder::PreOrder>(
+        [&](transform::NamedSequenceOp namedSequenceOp) {
+          if (namedSequenceOp.getSymName() == entryPoint) {
+            transform = cast<transform::TransformOpInterface>(
+                namedSequenceOp.getOperation());
+            return WalkResult::interrupt();
+          }
+          return WalkResult::advance();
+        });
+    if (transform)
+      return transform;
+  }
+  // Report the failure on `root`; callers only need to check for null.
+  root->emitError() << "could not find a nested named sequence with name: "
+                    << entryPoint;
+  return nullptr;
+}
+
+LogicalResult transform::detail::defineDeclaredSymbols(Block &block,
+                                                       ModuleOp definitions) {
+  MLIRContext &ctx = *definitions->getContext();
+  auto consumedName =
+      StringAttr::get(&ctx, transform::TransformDialect::kArgConsumedAttrName);
+  auto readOnlyName =
+      StringAttr::get(&ctx, transform::TransformDialect::kArgReadOnlyAttrName);
+  // `definitions` is only queried below, so build its symbol table once.
+  SymbolTable symbolTable(definitions);
+  for (Operation &op : llvm::make_early_inc_range(block)) {
+    LLVM_DEBUG(DBGS() << op << "\n");
+    auto symbol = dyn_cast<SymbolOpInterface>(op);
+    if (!symbol)
+      continue;
+    if (symbol->getNumRegions() == 1 && !symbol->getRegion(0).empty())
+      continue;
+
+    LLVM_DEBUG(DBGS() << "looking for definition of symbol "
+                      << symbol.getNameAttr() << ":");
+    Operation *externalSymbol = symbolTable.lookup(symbol.getNameAttr());
+    if (!externalSymbol || externalSymbol->getNumRegions() != 1 ||
+        externalSymbol->getRegion(0).empty()) {
+      LLVM_DEBUG(llvm::dbgs() << "not found\n");
+      continue;
+    }
+
+    auto symbolFunc = dyn_cast<FunctionOpInterface>(op);
+    auto externalSymbolFunc = dyn_cast<FunctionOpInterface>(externalSymbol);
+    if (!symbolFunc || !externalSymbolFunc) {
+      LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n");
+      continue;
+    }
+
+    LLVM_DEBUG(llvm::dbgs() << "found @" << symbol.getName() << "\n");
+    if (symbolFunc.getFunctionType() != externalSymbolFunc.getFunctionType()) {
+      return symbolFunc.emitError()
+             << "external definition has a mismatching signature ("
+             << externalSymbolFunc.getFunctionType() << ")";
+    }
+    // Validate every argument before any IR is created or modified.
+    for (unsigned i = 0, e = symbolFunc.getNumArguments(); i < e; ++i) {
+      bool isExternalConsumed =
+          externalSymbolFunc.getArgAttr(i, consumedName) != nullptr;
+      bool isExternalReadonly =
+          externalSymbolFunc.getArgAttr(i, readOnlyName) != nullptr;
+      bool isConsumed = symbolFunc.getArgAttr(i, consumedName) != nullptr;
+      bool isReadonly = symbolFunc.getArgAttr(i, readOnlyName) != nullptr;
+      if ((isExternalConsumed && !isConsumed) ||
+          (isExternalReadonly && !isReadonly)) {
+        return symbolFunc.emitError()
+               << "external definition has mismatching consumption annotations "
+                  "for argument #"
+               << i;
+      }
+    }
+    // Replace the declaration by a clone of the definition. Annotations are
+    // inherited on the clone only, so `definitions` itself is never modified.
+    auto clonedFunc =
+        cast<FunctionOpInterface>(OpBuilder(&op).clone(*externalSymbol));
+    for (unsigned i = 0, e = clonedFunc.getNumArguments(); i < e; ++i) {
+      if (clonedFunc.getArgAttr(i, consumedName) ||
+          clonedFunc.getArgAttr(i, readOnlyName))
+        continue;
+      if (symbolFunc.getArgAttr(i, consumedName))
+        clonedFunc.setArgAttr(i, consumedName, UnitAttr::get(&ctx));
+      else if (symbolFunc.getArgAttr(i, readOnlyName))
+        clonedFunc.setArgAttr(i, readOnlyName, UnitAttr::get(&ctx));
+    }
+    symbol->erase();
+  }
+  return success();
+}
+
+LogicalResult transform::applyTransformNamedSequence(
+    Operation *payload, ModuleOp transformModule,
+    const TransformOptions &options, StringRef entryPoint) {
+  Operation *transformRoot =
+      detail::findTransformEntryPoint(payload, transformModule, entryPoint);
+  if (!transformRoot)
+    return failure();
+  // `transformModule` may not be modified; link its definitions in instead.
+  // That may erase `transformRoot` (if a declaration), so look it up again.
+  if (transformModule && !transformModule->isAncestor(transformRoot)) {
+    if (failed(detail::defineDeclaredSymbols(*transformRoot->getBlock(),
+                                             transformModule)))
+      return failure();
+    transformRoot =
+        detail::findTransformEntryPoint(payload, transformModule, entryPoint);
+  }
+  // Apply the transform to the IR, do not enforce top-level constraints.
+  return applyTransforms(payload, cast<TransformOpInterface>(transformRoot),
+                         /*extraMapping=*/{}, options,
+                         /*enforceToplevelTransformOp=*/false);
+}
diff --git a/mlir/unittests/Dialect/Transform/CMakeLists.txt b/mlir/unittests/Dialect/Transform/CMakeLists.txt
index 1fecd21221c91c8..89238a0bdae16ea 100644
--- a/mlir/unittests/Dialect/Transform/CMakeLists.txt
+++ b/mlir/unittests/Dialect/Transform/CMakeLists.txt
@@ -1,8 +1,11 @@
add_mlir_unittest(MLIRTransformDialectTests
BuildOnlyExtensionTest.cpp
+ Preload.cpp
)
target_link_libraries(MLIRTransformDialectTests
PRIVATE
MLIRFuncDialect
+ MLIRTestTransformDialect
MLIRTransformDialect
+ MLIRTransformDialectTransforms
)
diff --git a/mlir/unittests/Dialect/Transform/Preload.cpp b/mlir/unittests/Dialect/Transform/Preload.cpp
new file mode 100644
index 000000000000000..b6b6e74c3c34e24
--- /dev/null
+++ b/mlir/unittests/Dialect/Transform/Preload.cpp
@@ -0,0 +1,90 @@
+//===- Preload.cpp - Test transform dialect library preloading ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/TypeID.h"
+#include "mlir/Tools/mlir-opt/MlirOptMain.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/raw_ostream.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+// Expected to come from the linked MLIRTestTransformDialect library; provides
+// test ops such as `transform.test_print_remark_at_operand` used below.
+namespace test {
+void registerTestTransformDialectExtension(DialectRegistry &registry);
+} // namespace test
+
+const static llvm::StringLiteral library = R"MLIR(
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence public @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ transform.test_print_remark_at_operand %arg0, "from external symbol" : !transform.any_op
+ transform.yield
+ }
+})MLIR";
+
+const static llvm::StringLiteral input = R"MLIR(
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence private @__transform_main(%arg0: !transform.any_op {transform.readonly})
+
+ transform.sequence failures(propagate) {
+ ^bb0(%arg0: !transform.any_op):
+ include @__transform_main failures(propagate) (%arg0) : (!transform.any_op) -> ()
+ }
+})MLIR";
+
+TEST(Preload, ContextPreloadConstructedLibrary) {
+  MLIRContext context;
+  auto *dialect = context.getOrLoadDialect<transform::TransformDialect>();
+  DialectRegistry registry;
+  ::test::registerTestTransformDialectExtension(registry);
+  registry.applyExtensions(&context);
+  ParserConfig parserConfig(&context);
+
+  // Each step below dereferences the previous step's result, so use ASSERT_*
+  // to stop the test instead of continuing with a null handle.
+  OwningOpRef<ModuleOp> inputModule =
+      parseSourceString<ModuleOp>(input, parserConfig, "<input>");
+  ASSERT_TRUE(inputModule) << "failed to parse input module";
+
+  OwningOpRef<ModuleOp> transformLibrary =
+      parseSourceString<ModuleOp>(library, parserConfig, "<transform-library>");
+  ASSERT_TRUE(transformLibrary) << "failed to parse transform module";
+  dialect->registerLibraryModule(std::move(transformLibrary));
+
+  ModuleOp retrievedTransformLibrary =
+      transform::detail::getPreloadedTransformInterpreterModule(&context);
+  ASSERT_TRUE(retrievedTransformLibrary)
+      << "failed to retrieve transform module";
+
+  transform::TransformOpInterface entryPoint =
+      transform::detail::findTransformEntryPoint(inputModule->getOperation(),
+                                                 retrievedTransformLibrary);
+  ASSERT_TRUE(entryPoint) << "failed to find entry point";
+
+  // This replaces the `@__transform_main` declaration of `input` (which is
+  // what `entryPoint` refers to) with the library definition, so `entryPoint`
+  // must not be used past this point.
+  LogicalResult res = transform::detail::defineDeclaredSymbols(
+      inputModule->getBodyRegion().front(), retrievedTransformLibrary);
+  ASSERT_TRUE(succeeded(res)) << "failed to define declared symbols";
+
+  transform::TransformOptions options;
+  res = transform::applyTransformNamedSequence(
+      inputModule->getOperation(), retrievedTransformLibrary, options);
+  EXPECT_TRUE(succeeded(res)) << "failed to apply named sequence";
+}
More information about the Mlir-commits
mailing list