[Mlir-commits] [mlir] [mlir][Transform] Provide a minimal set of utils that allow implementing a simple transform dialect interpreter pass (PR #68330)

Nicolas Vasilache llvmlistbot at llvm.org
Fri Oct 6 03:47:03 PDT 2023


https://github.com/nicolasvasilache updated https://github.com/llvm/llvm-project/pull/68330

>From 46621a376a75ec004bf3ef3932ec8c9e7ada2340 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolas.vasilache at gmail.com>
Date: Thu, 5 Oct 2023 16:11:00 +0000
Subject: [PATCH] [mlir][Transform] Provide a minimal set of utils that allow
 implementing a simple transform dialect interpreter pass

---
 .../Dialect/Transform/IR/TransformDialect.td  |  29 ++-
 .../Transform/IR/TransformInterfaces.h        |   5 +-
 .../Transforms/TransformInterpreterUtils.h    |  78 ++++++++
 .../Transform/IR/TransformInterfaces.cpp      |  26 +--
 .../Transform/Transforms/CMakeLists.txt       |   1 +
 .../TransformInterpreterPassBase.cpp          | 115 +----------
 .../Transforms/TransformInterpreterUtils.cpp  | 188 ++++++++++++++++++
 .../Dialect/Transform/CMakeLists.txt          |   3 +
 mlir/unittests/Dialect/Transform/Preload.cpp  |  90 +++++++++
 9 files changed, 410 insertions(+), 125 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
 create mode 100644 mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
 create mode 100644 mlir/unittests/Dialect/Transform/Preload.cpp

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
index 70a76ab9670f907..1097b15da77bb7d 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
@@ -20,10 +20,14 @@ def Transform_Dialect : Dialect {
 
   let hasOperationAttrVerify = 1;
   let extraClassDeclaration = [{
+    /// Symbol name for the default entry point "named sequence".
+    constexpr const static ::llvm::StringLiteral
+        kTransformEntryPointSymbolName = "__transform_main";
+    
     /// Name of the attribute attachable to the symbol table operation
     /// containing named sequences. This is used to trigger verification.
-    constexpr const static ::llvm::StringLiteral kWithNamedSequenceAttrName =
-        "transform.with_named_sequence";
+    constexpr const static ::llvm::StringLiteral
+        kWithNamedSequenceAttrName = "transform.with_named_sequence";
 
     /// Name of the attribute attachable to an operation so it can be
     /// identified as root by the default interpreter pass.
@@ -65,6 +69,22 @@ def Transform_Dialect : Dialect {
     using ExtensionTypePrintingHook =
         std::function<void(::mlir::Type, ::mlir::AsmPrinter &)>;
 
+    /// Takes ownership of `library` and appends it to the transform symbol
+    /// libraries made available to all users of this dialect.
+    void registerLibraryModule(::mlir::OwningOpRef<::mlir::ModuleOp> &&
+                                library) {
+      libraryModules.push_back(std::move(library));
+    }
+
+    /// Returns a range of non-owning handles to the registered library modules.
+    auto getLibraryModules() const {
+      return ::llvm::map_range(
+          libraryModules,
+          [](const ::mlir::OwningOpRef<::mlir::ModuleOp> &library) {
+        return library.get();
+      });
+    }
+
   private:
     /// Registers operations specified as template parameters with this
     /// dialect. Checks that they implement the required interfaces.
@@ -123,6 +143,11 @@ def Transform_Dialect : Dialect {
     /// lookups when the type is fully constructed.
     ::llvm::DenseMap<::mlir::TypeID, ExtensionTypePrintingHook>
         typePrintingHooks;
+
+    /// Modules containing symbols, e.g. named sequences, that will be
+    /// resolved by the interpreter when used.
+    ::llvm::SmallVector<::mlir::OwningOpRef<::mlir::ModuleOp>, 2>
+        libraryModules;
   }];
 }
 
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 0e72a93e685e32f..7b37245fc3d117b 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -111,7 +111,8 @@ class TransformOptions {
 LogicalResult
 applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
                 const RaggedArray<MappedValue> &extraMapping = {},
-                const TransformOptions &options = TransformOptions());
+                const TransformOptions &options = TransformOptions(),
+                bool enforceToplevelTransformOp = true);
 
 /// The state maintained across applications of various ops implementing the
 /// TransformOpInterface. The operations implementing this interface and the
@@ -193,7 +194,7 @@ class TransformState {
 
   friend LogicalResult applyTransforms(Operation *, TransformOpInterface,
                                        const RaggedArray<MappedValue> &,
-                                       const TransformOptions &);
+                                       const TransformOptions &, bool);
 
   friend TransformState
   detail::makeTransformStateForTesting(Region *region, Operation *payloadRoot);
diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
new file mode 100644
index 000000000000000..6640c15d6b05729
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
@@ -0,0 +1,78 @@
+//===- TransformInterpreterUtils.h - Transform Utils ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERUTILS_H
+#define MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERUTILS_H
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include <memory>
+
+namespace mlir {
+struct LogicalResult;
+class MLIRContext;
+class ModuleOp;
+class Operation;
+template <typename>
+class OwningOpRef;
+class Region;
+
+namespace transform {
+namespace detail {
+/// Parses and verifies the transform dialect module in the `transformFileName`
+/// MLIR file into `transformModule`; succeeds as a no-op on an empty file name.
+LogicalResult
+parseTransformInterpreterModule(MLIRContext *context,
+                                llvm::StringRef transformFileName,
+                                OwningOpRef<ModuleOp> &transformModule);
+
+/// Utility to fetch a transform interpreter module that has already been
+/// preloaded in the context instead of parsing it.
+/// This mode is useful in cases where explicit parsing of a transform library
+/// from file is expected to be prohibitively expensive.
+/// Loads the transform dialect into `context` if needed and returns the first
+/// library module registered with it via `registerLibraryModule`.
+/// Returns null if no library module has been registered.
+ModuleOp getPreloadedTransformInterpreterModule(MLIRContext *context);
+
+/// Finds the first `NamedSequenceOp` whose symbol name is `entryPoint` and
+/// that is either:
+///   1. nested under `root` (takes precedence).
+///   2. nested under `module` (if non-null), if not found in `root`.
+/// Reports an error on `root` and returns null if no such operation is found.
+TransformOpInterface findTransformEntryPoint(
+    Operation *root, ModuleOp module,
+    StringRef entryPoint = TransformDialect::kTransformEntryPointSymbolName);
+
+/// Replaces external symbols (declarations) in `block` with clones of their
+/// (non-external) definitions found in the `definitions` module.
+LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions);
+} // namespace detail
+
+/// Standalone util to apply the named sequence `entryPoint` to the payload.
+/// This is done in 3 steps:
+///   1. lookup the `entryPoint` symbol in `{payload, transformModule}` by
+///   calling detail::findTransformEntryPoint.
+///   2. if the entry point is found and not nested under `transformModule`,
+///   call `detail::defineDeclaredSymbols` to "link" in the `transformModule`.
+///   Note: this may modify the transform IR embedded with the payload IR.
+///   3. apply the transform IR to the payload IR, relaxing the requirement that
+///   the transform IR is a top-level transform op. We are applying a named
+///   sequence anyway.
+/// Returns failure if the entry point is not found or if any step fails.
+LogicalResult applyTransformNamedSequence(
+    Operation *payload, ModuleOp transformModule,
+    const TransformOptions &options,
+    StringRef entryPoint = TransformDialect::kTransformEntryPointSymbolName);
+
+} // namespace transform
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERUTILS_H
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 4a9bb2dba7d660c..4f88b8522e54c80 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -2079,20 +2079,20 @@ LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
 // Entry point.
 //===----------------------------------------------------------------------===//
 
-LogicalResult
-transform::applyTransforms(Operation *payloadRoot,
-                           TransformOpInterface transform,
-                           const RaggedArray<MappedValue> &extraMapping,
-                           const TransformOptions &options) {
-#ifndef NDEBUG
-  if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
-      transform->getNumOperands() != 0) {
-    transform->emitError()
-        << "expected transform to start at the top-level transform op";
-    llvm::report_fatal_error("could not run transforms",
-                             /*gen_crash_diag=*/false);
+LogicalResult transform::applyTransforms(
+    Operation *payloadRoot, TransformOpInterface transform,
+    const RaggedArray<MappedValue> &extraMapping,
+    const TransformOptions &options, bool enforceToplevelTransformOp) {
+  if (enforceToplevelTransformOp) {
+    if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
+        transform->getNumOperands() != 0) {
+      return transform->emitError()
+             << "expected transform to start at the top-level transform op";
+    }
+  } else if (failed(
+                 detail::verifyPossibleTopLevelTransformOpTrait(transform))) {
+    return failure();
   }
-#endif // NDEBUG
 
   TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
                        options);
diff --git a/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
index 3f51ef1088f7af6..8774a8b86fb0d91 100644
--- a/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRTransformDialectTransforms
   CheckUses.cpp
   InferEffects.cpp
   TransformInterpreterPassBase.cpp
+  TransformInterpreterUtils.cpp
 
   DEPENDS
   MLIRTransformDialectTransformsIncGen
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index 23640c92457a89d..193f3c5447c29ea 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/IR/Visitors.h"
@@ -50,34 +51,6 @@ constexpr static llvm::StringLiteral kTransformDialectTagPayloadRootValue =
 constexpr static llvm::StringLiteral
     kTransformDialectTagTransformContainerValue = "transform_container";
 
-/// Utility to parse the content of a `transformFileName` MLIR file containing
-/// a transform dialect specification.
-static LogicalResult
-parseTransformModuleFromFile(MLIRContext *context,
-                             llvm::StringRef transformFileName,
-                             OwningOpRef<ModuleOp> &transformModule) {
-  if (transformFileName.empty()) {
-    LLVM_DEBUG(
-        DBGS() << "no transform file name specified, assuming the transform "
-                  "module is embedded in the IR next to the top-level\n");
-    return success();
-  }
-  // Parse transformFileName content into a ModuleOp.
-  std::string errorMessage;
-  auto memoryBuffer = mlir::openInputFile(transformFileName, &errorMessage);
-  if (!memoryBuffer) {
-    return emitError(FileLineColLoc::get(
-               StringAttr::get(context, transformFileName), 0, 0))
-           << "failed to open transform file: " << errorMessage;
-  }
-  // Tell sourceMgr about this buffer, the parser will pick it up.
-  llvm::SourceMgr sourceMgr;
-  sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
-  transformModule =
-      OwningOpRef<ModuleOp>(parseSourceFile<ModuleOp>(sourceMgr, context));
-  return success();
-}
-
 /// Finds the single top-level transform operation with `root` as ancestor.
 /// Reports an error if there is more than one such operation and returns the
 /// first one found. Reports an error returns nullptr if no such operation
@@ -302,80 +275,6 @@ static void performOptionalDebugActions(
     transform->removeAttr(kTransformDialectTagAttrName);
 }
 
-/// Replaces external symbols in `block` with their (non-external) definitions
-/// from the given module.
-static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
-  MLIRContext &ctx = *definitions->getContext();
-  auto consumedName =
-      StringAttr::get(&ctx, transform::TransformDialect::kArgConsumedAttrName);
-  auto readOnlyName =
-      StringAttr::get(&ctx, transform::TransformDialect::kArgReadOnlyAttrName);
-
-  for (Operation &op : llvm::make_early_inc_range(block)) {
-    LLVM_DEBUG(DBGS() << op << "\n");
-    auto symbol = dyn_cast<SymbolOpInterface>(op);
-    if (!symbol)
-      continue;
-    if (symbol->getNumRegions() == 1 && !symbol->getRegion(0).empty())
-      continue;
-
-    LLVM_DEBUG(DBGS() << "looking for definition of symbol "
-                      << symbol.getNameAttr() << ":");
-    SymbolTable symbolTable(definitions);
-    Operation *externalSymbol = symbolTable.lookup(symbol.getNameAttr());
-    if (!externalSymbol || externalSymbol->getNumRegions() != 1 ||
-        externalSymbol->getRegion(0).empty()) {
-      LLVM_DEBUG(llvm::dbgs() << "not found\n");
-      continue;
-    }
-
-    auto symbolFunc = dyn_cast<FunctionOpInterface>(op);
-    auto externalSymbolFunc = dyn_cast<FunctionOpInterface>(externalSymbol);
-    if (!symbolFunc || !externalSymbolFunc) {
-      LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n");
-      continue;
-    }
-
-    LLVM_DEBUG(llvm::dbgs() << "found @" << externalSymbol << "\n");
-    if (symbolFunc.getFunctionType() != externalSymbolFunc.getFunctionType()) {
-      return symbolFunc.emitError()
-             << "external definition has a mismatching signature ("
-             << externalSymbolFunc.getFunctionType() << ")";
-    }
-
-    for (unsigned i = 0, e = symbolFunc.getNumArguments(); i < e; ++i) {
-      bool isExternalConsumed =
-          externalSymbolFunc.getArgAttr(i, consumedName) != nullptr;
-      bool isExternalReadonly =
-          externalSymbolFunc.getArgAttr(i, readOnlyName) != nullptr;
-      bool isConsumed = symbolFunc.getArgAttr(i, consumedName) != nullptr;
-      bool isReadonly = symbolFunc.getArgAttr(i, readOnlyName) != nullptr;
-      if (!isExternalConsumed && !isExternalReadonly) {
-        if (isConsumed)
-          externalSymbolFunc.setArgAttr(i, consumedName, UnitAttr::get(&ctx));
-        else if (isReadonly)
-          externalSymbolFunc.setArgAttr(i, readOnlyName, UnitAttr::get(&ctx));
-        continue;
-      }
-
-      if ((isExternalConsumed && !isConsumed) ||
-          (isExternalReadonly && !isReadonly)) {
-        return symbolFunc.emitError()
-               << "external definition has mismatching consumption annotations "
-                  "for argument #"
-               << i;
-      }
-    }
-
-    OpBuilder builder(&op);
-    builder.setInsertionPoint(&op);
-    builder.clone(*externalSymbol);
-    symbol->erase();
-  }
-
-  return success();
-}
-
 LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
     Operation *target, StringRef passName,
     const std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
@@ -443,8 +342,8 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
       diag.attachNote(target->getLoc()) << "pass anchor op";
       return diag;
     }
-    if (failed(defineDeclaredSymbols(*transformRoot->getBlock(),
-                                     transformLibraryModule->get())))
+    if (failed(detail::defineDeclaredSymbols(*transformRoot->getBlock(),
+                                             transformLibraryModule->get())))
       return failure();
   }
 
@@ -471,15 +370,15 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
     function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
         moduleBuilder) {
   OwningOpRef<ModuleOp> parsedTransformModule;
-  if (failed(parseTransformModuleFromFile(context, transformFileName,
-                                          parsedTransformModule)))
+  if (failed(detail::parseTransformInterpreterModule(
+          context, transformFileName, parsedTransformModule)))
     return failure();
   if (parsedTransformModule && failed(mlir::verify(*parsedTransformModule)))
     return failure();
 
   OwningOpRef<ModuleOp> parsedLibraryModule;
-  if (failed(parseTransformModuleFromFile(context, transformLibraryFileName,
-                                          parsedLibraryModule)))
+  if (failed(detail::parseTransformInterpreterModule(
+          context, transformLibraryFileName, parsedLibraryModule)))
     return failure();
     return failure();
   if (parsedLibraryModule && failed(mlir::verify(*parsedLibraryModule)))
     return failure();
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
new file mode 100644
index 000000000000000..a68b79813aa0f36
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
@@ -0,0 +1,188 @@
+//===- TransformInterpreterUtils.cpp --------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Lightweight transform dialect interpreter utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Support/FileUtilities.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+#define DEBUG_TYPE "transform-dialect-interpreter-utils"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+LogicalResult transform::detail::parseTransformInterpreterModule(
+    MLIRContext *context, llvm::StringRef transformFileName,
+    OwningOpRef<ModuleOp> &transformModule) {
+  if (transformFileName.empty()) {
+    LLVM_DEBUG(
+        DBGS() << "no transform file name specified, assuming the transform "
+                  "module is embedded in the IR next to the top-level\n");
+    return success();
+  }
+  // Parse transformFileName content into a ModuleOp.
+  std::string errorMessage;
+  auto memoryBuffer = mlir::openInputFile(transformFileName, &errorMessage);
+  if (!memoryBuffer) {
+    return emitError(FileLineColLoc::get(
+               StringAttr::get(context, transformFileName), 0, 0))
+           << "failed to open transform file: " << errorMessage;
+  }
+  // Tell sourceMgr about this buffer, the parser will pick it up.
+  llvm::SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
+  transformModule = parseSourceFile<ModuleOp>(sourceMgr, context);
+  // A failed parse leaves the module null (the parser already reported why).
+  return transformModule ? mlir::verify(*transformModule) : failure();
+}
+
+ModuleOp transform::detail::getPreloadedTransformInterpreterModule(
+    MLIRContext *context) {
+  auto *dialect = context->getOrLoadDialect<transform::TransformDialect>();
+  auto libraries = dialect->getLibraryModules();
+  // Only the first registered library is used; null signals "none registered".
+  if (libraries.empty())
+    return ModuleOp();
+  return *libraries.begin();
+}
+
+transform::TransformOpInterface
+transform::detail::findTransformEntryPoint(Operation *root, ModuleOp module,
+                                           StringRef entryPoint) {
+  // Search scopes in priority order: the payload root first, then the
+  // (optional) separate transform module.
+  SmallVector<Operation *, 2> searchScopes{root};
+  if (module)
+    searchScopes.push_back(module);
+  for (Operation *scope : searchScopes) {
+    transform::TransformOpInterface match = nullptr;
+    // Pre-order walk, stopping at the first named sequence with that name.
+    scope->walk<WalkOrder::PreOrder>([&](transform::NamedSequenceOp sequence) {
+      if (sequence.getSymName() != entryPoint)
+        return WalkResult::advance();
+      match = cast<transform::TransformOpInterface>(sequence.getOperation());
+      return WalkResult::interrupt();
+    });
+    if (match)
+      return match;
+  }
+  // The in-flight diagnostic is reported when the temporary is destroyed.
+  root->emitError() << "could not find a nested named sequence with name: "
+                    << entryPoint;
+  return nullptr;
+}
+
+LogicalResult transform::detail::defineDeclaredSymbols(Block &block,
+                                                       ModuleOp definitions) {
+  MLIRContext &ctx = *definitions->getContext();
+  auto consumedName =
+      StringAttr::get(&ctx, transform::TransformDialect::kArgConsumedAttrName);
+  auto readOnlyName =
+      StringAttr::get(&ctx, transform::TransformDialect::kArgReadOnlyAttrName);
+  // The symbols of `definitions` are not changed below: build the table once.
+  SymbolTable symbolTable(definitions);
+  for (Operation &op : llvm::make_early_inc_range(block)) {
+    LLVM_DEBUG(DBGS() << op << "\n");
+    auto symbol = dyn_cast<SymbolOpInterface>(op);
+    if (!symbol)
+      continue;
+    if (symbol->getNumRegions() == 1 && !symbol->getRegion(0).empty())
+      continue;
+
+    LLVM_DEBUG(DBGS() << "looking for definition of symbol "
+                      << symbol.getNameAttr() << ":");
+    Operation *externalSymbol = symbolTable.lookup(symbol.getNameAttr());
+    if (!externalSymbol || externalSymbol->getNumRegions() != 1 ||
+        externalSymbol->getRegion(0).empty()) {
+      LLVM_DEBUG(llvm::dbgs() << "not found\n");
+      continue;
+    }
+
+    auto symbolFunc = dyn_cast<FunctionOpInterface>(op);
+    auto externalSymbolFunc = dyn_cast<FunctionOpInterface>(externalSymbol);
+    if (!symbolFunc || !externalSymbolFunc) {
+      LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n");
+      continue;
+    }
+
+    LLVM_DEBUG(llvm::dbgs() << "found @" << externalSymbol << "\n");
+    if (symbolFunc.getFunctionType() != externalSymbolFunc.getFunctionType()) {
+      return symbolFunc.emitError()
+             << "external definition has a mismatching signature ("
+             << externalSymbolFunc.getFunctionType() << ")";
+    }
+    // Reconcile argument annotations; may annotate the op in `definitions`.
+    for (unsigned i = 0, e = symbolFunc.getNumArguments(); i < e; ++i) {
+      bool isExternalConsumed =
+          externalSymbolFunc.getArgAttr(i, consumedName) != nullptr;
+      bool isExternalReadonly =
+          externalSymbolFunc.getArgAttr(i, readOnlyName) != nullptr;
+      bool isConsumed = symbolFunc.getArgAttr(i, consumedName) != nullptr;
+      bool isReadonly = symbolFunc.getArgAttr(i, readOnlyName) != nullptr;
+      if (!isExternalConsumed && !isExternalReadonly) {
+        if (isConsumed)
+          externalSymbolFunc.setArgAttr(i, consumedName, UnitAttr::get(&ctx));
+        else if (isReadonly)
+          externalSymbolFunc.setArgAttr(i, readOnlyName, UnitAttr::get(&ctx));
+        continue;
+      }
+      // Otherwise, the declaration must carry the definition's annotations.
+      if ((isExternalConsumed && !isConsumed) ||
+          (isExternalReadonly && !isReadonly)) {
+        return symbolFunc.emitError()
+               << "external definition has mismatching consumption annotations "
+                  "for argument #"
+               << i;
+      }
+    }
+
+    // Clone the definition right before the declaration, then drop the latter.
+    OpBuilder builder(&op);
+    builder.clone(*externalSymbol);
+    symbol->erase();
+  }
+
+  return success();
+}
+
+LogicalResult transform::applyTransformNamedSequence(
+    Operation *payload, ModuleOp transformModule,
+    const TransformOptions &options, StringRef entryPoint) {
+  // Step 1: locate the entry point; the lookup reports its own error.
+  TransformOpInterface entry =
+      detail::findTransformEntryPoint(payload, transformModule, entryPoint);
+  if (!entry)
+    return failure();
+
+  // Step 2: when the entry point is not nested in `transformModule`, link the
+  // module's definitions into the block that contains the entry point.
+  bool isEmbedded = transformModule && !transformModule->isAncestor(entry);
+  if (isEmbedded && failed(detail::defineDeclaredSymbols(*entry->getBlock(),
+                                                         transformModule)))
+    return failure();
+
+  // Step 3: run the interpreter without the top-level op restrictions.
+  RaggedArray<MappedValue> extraMapping;
+  return applyTransforms(payload, entry, extraMapping, options,
+                         /*enforceToplevelTransformOp=*/false);
+}
diff --git a/mlir/unittests/Dialect/Transform/CMakeLists.txt b/mlir/unittests/Dialect/Transform/CMakeLists.txt
index 1fecd21221c91c8..89238a0bdae16ea 100644
--- a/mlir/unittests/Dialect/Transform/CMakeLists.txt
+++ b/mlir/unittests/Dialect/Transform/CMakeLists.txt
@@ -1,8 +1,11 @@
 add_mlir_unittest(MLIRTransformDialectTests
   BuildOnlyExtensionTest.cpp
+  Preload.cpp
 )
 target_link_libraries(MLIRTransformDialectTests
   PRIVATE
   MLIRFuncDialect
+  MLIRTestTransformDialect
   MLIRTransformDialect
+  MLIRTransformDialectTransforms
 )
diff --git a/mlir/unittests/Dialect/Transform/Preload.cpp b/mlir/unittests/Dialect/Transform/Preload.cpp
new file mode 100644
index 000000000000000..b6b6e74c3c34e24
--- /dev/null
+++ b/mlir/unittests/Dialect/Transform/Preload.cpp
@@ -0,0 +1,90 @@
+//===- Preload.cpp - Test preloaded transform library modules -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/TypeID.h"
+#include "mlir/Tools/mlir-opt/MlirOptMain.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/raw_ostream.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+namespace mlir {
+namespace test {
+std::unique_ptr<Pass> createTestTransformDialectInterpreterPass();
+} // namespace test
+} // namespace mlir
+namespace test {
+void registerTestTransformDialectExtension(DialectRegistry &registry);
+} // namespace test
+
+const static llvm::StringLiteral library = R"MLIR(
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence public @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "from external symbol" : !transform.any_op
+    transform.yield
+  }
+})MLIR";
+
+const static llvm::StringLiteral input = R"MLIR(
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence private @__transform_main(%arg0: !transform.any_op {transform.readonly})
+
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    include @__transform_main failures(propagate) (%arg0) : (!transform.any_op) -> ()
+  }
+})MLIR";
+
+TEST(Preload, ContextPreloadConstructedLibrary) {
+  registerPassManagerCLOptions();
+
+  MLIRContext context;
+  auto *dialect = context.getOrLoadDialect<transform::TransformDialect>();
+  DialectRegistry registry;
+  ::test::registerTestTransformDialectExtension(registry);
+  registry.applyExtensions(&context);
+  ParserConfig parserConfig(&context);
+  // Both modules are dereferenced below: a failed parse must abort the test.
+  OwningOpRef<ModuleOp> inputModule =
+      parseSourceString<ModuleOp>(input, parserConfig, "<input>");
+  ASSERT_TRUE(inputModule) << "failed to parse input module";
+
+  OwningOpRef<ModuleOp> transformLibrary =
+      parseSourceString<ModuleOp>(library, parserConfig, "<transform-library>");
+  ASSERT_TRUE(transformLibrary) << "failed to parse transform module";
+  dialect->registerLibraryModule(std::move(transformLibrary));
+
+  ModuleOp retrievedTransformLibrary =
+      transform::detail::getPreloadedTransformInterpreterModule(&context);
+  ASSERT_TRUE(retrievedTransformLibrary)
+      << "failed to retrieve transform module";
+
+  transform::TransformOpInterface entryPoint =
+      transform::detail::findTransformEntryPoint(inputModule->getOperation(),
+                                                 retrievedTransformLibrary);
+  EXPECT_TRUE(entryPoint) << "failed to find entry point";
+
+  LogicalResult res = transform::detail::defineDeclaredSymbols(
+      inputModule->getBodyRegion().front(), retrievedTransformLibrary);
+  EXPECT_TRUE(succeeded(res)) << "failed to define declared symbols";
+
+  transform::TransformOptions options;
+  res = transform::applyTransformNamedSequence(
+      inputModule->getOperation(), retrievedTransformLibrary, options);
+  EXPECT_TRUE(succeeded(res)) << "failed to apply named sequence";
+}



More information about the Mlir-commits mailing list