[Mlir-commits] [mlir] [mlir][Transform] Provide a minimal set of utils that allow implementing a simple transform dialect interpreter pass (PR #68330)

Nicolas Vasilache llvmlistbot at llvm.org
Fri Oct 6 05:08:02 PDT 2023


https://github.com/nicolasvasilache updated https://github.com/llvm/llvm-project/pull/68330

>From cdfa54029105b464d72fedd0d47dedae2fbef01d Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolas.vasilache at gmail.com>
Date: Thu, 5 Oct 2023 16:11:00 +0000
Subject: [PATCH] [mlir][Transform] Provide a minimal set of utils that allow
 implementing a simple transform dialect interpreter pass

---
 .../Dialect/Transform/IR/TransformDialect.td  |  29 +-
 .../Transform/IR/TransformInterfaces.h        |   5 +-
 .../Transforms/TransformInterpreterUtils.h    |  89 +++++
 .../Transform/IR/TransformInterfaces.cpp      |  26 +-
 .../Transform/Transforms/CMakeLists.txt       |   1 +
 .../TransformInterpreterPassBase.cpp          | 284 +--------------
 .../Transforms/TransformInterpreterUtils.cpp  | 337 ++++++++++++++++++
 .../Dialect/Transform/CMakeLists.txt          |   3 +
 mlir/unittests/Dialect/Transform/Preload.cpp  |  92 +++++
 9 files changed, 577 insertions(+), 289 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
 create mode 100644 mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
 create mode 100644 mlir/unittests/Dialect/Transform/Preload.cpp

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
index f28205a25507025..ad6804673b770ca 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
@@ -20,10 +20,14 @@ def Transform_Dialect : Dialect {
 
   let hasOperationAttrVerify = 1;
   let extraClassDeclaration = [{
+    /// Symbol name for the default entry point "named sequence".
+    constexpr const static ::llvm::StringLiteral
+        kTransformEntryPointSymbolName = "__transform_main";
+    
     /// Name of the attribute attachable to the symbol table operation
     /// containing named sequences. This is used to trigger verification.
-    constexpr const static ::llvm::StringLiteral kWithNamedSequenceAttrName =
-        "transform.with_named_sequence";
+    constexpr const static ::llvm::StringLiteral
+        kWithNamedSequenceAttrName = "transform.with_named_sequence";
 
     /// Name of the attribute attachable to an operation so it can be
     /// identified as root by the default interpreter pass.
@@ -74,6 +78,22 @@ def Transform_Dialect : Dialect {
     using ExtensionTypePrintingHook =
         std::function<void(::mlir::Type, ::mlir::AsmPrinter &)>;
 
+    /// Appends the given module as a transform symbol library available to
+    /// all dialect users.
+    void registerLibraryModule(::mlir::OwningOpRef<::mlir::ModuleOp> &&
+                                library) {
+      libraryModules.push_back(std::move(library));
+    }
+
+    /// Returns a range of registered library modules.
+    auto getLibraryModules() const {
+      return ::llvm::map_range(
+          libraryModules,
+          [](const ::mlir::OwningOpRef<::mlir::ModuleOp> &library) {
+        return library.get();
+      });
+    }
+
   private:
     /// Registers operations specified as template parameters with this
     /// dialect. Checks that they implement the required interfaces.
@@ -132,6 +152,11 @@ def Transform_Dialect : Dialect {
     /// lookups when the type is fully constructed.
     ::llvm::DenseMap<::mlir::TypeID, ExtensionTypePrintingHook>
         typePrintingHooks;
+
+    /// Modules containing symbols, e.g. named sequences, that will be
+    /// resolved by the interpreter when used.
+    ::llvm::SmallVector<::mlir::OwningOpRef<::mlir::ModuleOp>, 2>
+        libraryModules;
   }];
 }
 
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 0e72a93e685e32f..7b37245fc3d117b 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -111,7 +111,8 @@ class TransformOptions {
 LogicalResult
 applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
                 const RaggedArray<MappedValue> &extraMapping = {},
-                const TransformOptions &options = TransformOptions());
+                const TransformOptions &options = TransformOptions(),
+                bool enforceToplevelTransformOp = true);
 
 /// The state maintained across applications of various ops implementing the
 /// TransformOpInterface. The operations implementing this interface and the
@@ -193,7 +194,7 @@ class TransformState {
 
   friend LogicalResult applyTransforms(Operation *, TransformOpInterface,
                                        const RaggedArray<MappedValue> &,
-                                       const TransformOptions &);
+                                       const TransformOptions &, bool);
 
   friend TransformState
   detail::makeTransformStateForTesting(Region *region, Operation *payloadRoot);
diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
new file mode 100644
index 000000000000000..36c80e6fd61d3c1
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
@@ -0,0 +1,89 @@
+//===- TransformInterpreterUtils.h - Transform Utils ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERUTILS_H
+#define MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERUTILS_H
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include <memory>
+
+namespace mlir {
+struct LogicalResult;
+class MLIRContext;
+class ModuleOp;
+class Operation;
+template <typename>
+class OwningOpRef;
+class Region;
+
+namespace transform {
+namespace detail {
+/// Utility to parse and verify the content of a `transformFileName` MLIR file
+/// containing a transform dialect specification.
+LogicalResult
+parseTransformModuleFromFile(MLIRContext *context,
+                             llvm::StringRef transformFileName,
+                             OwningOpRef<ModuleOp> &transformModule);
+
+/// Utility to load a transform interpreter `module` from a module that has
+/// already been preloaded in the context.
+/// This mode is useful in cases where explicit parsing of a transform library
+/// from file is expected to be prohibitively expensive.
+/// In such cases, the transform module is expected to be found in the preloaded
+/// library modules of the transform dialect.
+/// Returns null if the module is not found.
+ModuleOp getPreloadedTransformModule(MLIRContext *context);
+
+/// Finds the first TransformOpInterface named `entryPoint` (defaults to
+/// `kTransformEntryPointSymbolName`) that is either:
+///   1. nested under `root` (takes precedence).
+///   2. nested under `module`, if not found in `root`.
+/// Reports errors and returns null if no such operation is found.
+TransformOpInterface findTransformEntryPoint(
+    Operation *root, ModuleOp module,
+    StringRef entryPoint = TransformDialect::kTransformEntryPointSymbolName);
+
+/// Merge all symbols from `other` into `target`. Both ops need to implement the
+/// `SymbolTable` trait. Operations are moved from `other`, i.e., `other` may be
+/// modified by this function and might not verify after the function returns.
+/// Upon merging, private symbols may be renamed in order to avoid collisions in
+/// the result. Public symbols may not collide, with the exception of
+/// instances of `SymbolOpInterface`, where collisions are allowed if at least
+/// one of the two is external, in which case the other op is preserved (or any
+/// one of the two if both are external).
+// TODO: Reconsider cloning individual ops rather than forcing users of the
+//       function to clone (or move) `other` in order to improve efficiency.
+//       This might primarily make sense if we can also prune the symbols that
+//       are merged to a subset (such as those that are actually used).
+LogicalResult mergeSymbolsInto(Operation *target,
+                               OwningOpRef<Operation *> other);
+} // namespace detail
+
+/// Standalone util to apply the named sequence `entryPoint` to the payload.
+/// This is done in 3 steps:
+///   1. lookup the `entryPoint` symbol in `{payload, transformModule}` by
+///   calling detail::findTransformEntryPoint.
+///   2. if the entry point is found and not nested under
+///   `transformModule`, call `detail::mergeSymbolsInto` to "link" in
+///   the `transformModule`. Note: this may modify the transform IR
+///   embedded with the payload IR.
+///   3. apply the transform IR to the payload IR, relaxing the requirement that
+///   the transform IR is a top-level transform op. We are applying a named
+///   sequence anyway.
+LogicalResult applyTransformNamedSequence(
+    Operation *payload, ModuleOp transformModule,
+    const TransformOptions &options,
+    StringRef entryPoint = TransformDialect::kTransformEntryPointSymbolName);
+
+} // namespace transform
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERUTILS_H
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 4a9bb2dba7d660c..4f88b8522e54c80 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -2079,20 +2079,20 @@ LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
 // Entry point.
 //===----------------------------------------------------------------------===//
 
-LogicalResult
-transform::applyTransforms(Operation *payloadRoot,
-                           TransformOpInterface transform,
-                           const RaggedArray<MappedValue> &extraMapping,
-                           const TransformOptions &options) {
-#ifndef NDEBUG
-  if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
-      transform->getNumOperands() != 0) {
-    transform->emitError()
-        << "expected transform to start at the top-level transform op";
-    llvm::report_fatal_error("could not run transforms",
-                             /*gen_crash_diag=*/false);
+LogicalResult transform::applyTransforms(
+    Operation *payloadRoot, TransformOpInterface transform,
+    const RaggedArray<MappedValue> &extraMapping,
+    const TransformOptions &options, bool enforceToplevelTransformOp) {
+  if (enforceToplevelTransformOp) {
+    if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
+        transform->getNumOperands() != 0) {
+      return transform->emitError()
+             << "expected transform to start at the top-level transform op";
+    }
+  } else if (failed(
+                 detail::verifyPossibleTopLevelTransformOpTrait(transform))) {
+    return failure();
   }
-#endif // NDEBUG
 
   TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
                        options);
diff --git a/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
index 3f51ef1088f7af6..8774a8b86fb0d91 100644
--- a/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRTransformDialectTransforms
   CheckUses.cpp
   InferEffects.cpp
   TransformInterpreterPassBase.cpp
+  TransformInterpreterUtils.cpp
 
   DEPENDS
   MLIRTransformDialectTransformsIncGen
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index 764d7e25854206e..ebfd7269f696bbc 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/IR/Visitors.h"
@@ -51,34 +52,6 @@ constexpr static llvm::StringLiteral kTransformDialectTagPayloadRootValue =
 constexpr static llvm::StringLiteral
     kTransformDialectTagTransformContainerValue = "transform_container";
 
-/// Utility to parse the content of a `transformFileName` MLIR file containing
-/// a transform dialect specification.
-static LogicalResult
-parseTransformModuleFromFile(MLIRContext *context,
-                             llvm::StringRef transformFileName,
-                             OwningOpRef<ModuleOp> &transformModule) {
-  if (transformFileName.empty()) {
-    LLVM_DEBUG(
-        DBGS() << "no transform file name specified, assuming the transform "
-                  "module is embedded in the IR next to the top-level\n");
-    return success();
-  }
-  // Parse transformFileName content into a ModuleOp.
-  std::string errorMessage;
-  auto memoryBuffer = mlir::openInputFile(transformFileName, &errorMessage);
-  if (!memoryBuffer) {
-    return emitError(FileLineColLoc::get(
-               StringAttr::get(context, transformFileName), 0, 0))
-           << "failed to open transform file: " << errorMessage;
-  }
-  // Tell sourceMgr about this buffer, the parser will pick it up.
-  llvm::SourceMgr sourceMgr;
-  sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
-  transformModule =
-      OwningOpRef<ModuleOp>(parseSourceFile<ModuleOp>(sourceMgr, context));
-  return success();
-}
-
 /// Finds the single top-level transform operation with `root` as ancestor.
 /// Reports an error if there is more than one such operation and returns the
 /// first one found. Reports an error returns nullptr if no such operation
@@ -295,239 +268,6 @@ static void performOptionalDebugActions(
     transform->removeAttr(kTransformDialectTagAttrName);
 }
 
-/// Return whether `func1` can be merged into `func2`. For that to work `func1`
-/// has to be a declaration (aka has to be external) and `func2` either has to
-/// be a declaration as well, or it has to be public (otherwise, it wouldn't
-/// be visible by `func1`).
-static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
-  return func1.isExternal() && (func2.isPublic() || func2.isExternal());
-}
-
-/// Merge `func1` into `func2`. The two ops must be inside the same parent op
-/// and mergable according to `canMergeInto`. The function erases `func1` such
-/// that only `func2` exists when the function returns.
-static LogicalResult mergeInto(FunctionOpInterface func1,
-                               FunctionOpInterface func2) {
-  assert(canMergeInto(func1, func2));
-  assert(func1->getParentOp() == func2->getParentOp() &&
-         "expected func1 and func2 to be in the same parent op");
-
-  // Check that function signatures match.
-  if (func1.getFunctionType() != func2.getFunctionType()) {
-    return func1.emitError()
-           << "external definition has a mismatching signature ("
-           << func2.getFunctionType() << ")";
-  }
-
-  // Check and merge argument attributes.
-  MLIRContext *context = func1->getContext();
-  auto *td = context->getLoadedDialect<transform::TransformDialect>();
-  StringAttr consumedName = td->getConsumedAttrName();
-  StringAttr readOnlyName = td->getReadOnlyAttrName();
-  for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) {
-    bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr;
-    bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr;
-    bool isConsumed = func1.getArgAttr(i, consumedName) != nullptr;
-    bool isReadonly = func1.getArgAttr(i, readOnlyName) != nullptr;
-    if (!isExternalConsumed && !isExternalReadonly) {
-      if (isConsumed)
-        func2.setArgAttr(i, consumedName, UnitAttr::get(context));
-      else if (isReadonly)
-        func2.setArgAttr(i, readOnlyName, UnitAttr::get(context));
-      continue;
-    }
-
-    if ((isExternalConsumed && !isConsumed) ||
-        (isExternalReadonly && !isReadonly)) {
-      return func1.emitError()
-             << "external definition has mismatching consumption "
-                "annotations for argument #"
-             << i;
-    }
-  }
-
-  // `func1` is the external one, so we can remove it.
-  assert(func1.isExternal());
-  func1->erase();
-
-  return success();
-}
-
-/// Merge all symbols from `other` into `target`. Both ops need to implement the
-/// `SymbolTable` trait. Operations are moved from `other`, i.e., `other` may be
-/// modified by this function and might not verify after the function returns.
-/// Upon merging, private symbols may be renamed in order to avoid collisions in
-/// the result. Public symbols may not collide, with the exception of
-/// instances of `SymbolOpInterface`, where collisions are allowed if at least
-/// one of the two is external, in which case the other op preserved (or any one
-/// of the two if both are external).
-// TODO: Reconsider cloning individual ops rather than forcing users of the
-//       function to clone (or move) `other` in order to improve efficiency.
-//       This might primarily make sense if we can also prune the symbols that
-//       are merged to a subset (such as those that are actually used).
-static LogicalResult mergeSymbolsInto(Operation *target,
-                                      OwningOpRef<Operation *> other) {
-  assert(target->hasTrait<OpTrait::SymbolTable>() &&
-         "requires target to implement the 'SymbolTable' trait");
-  assert(other->hasTrait<OpTrait::SymbolTable>() &&
-         "requires target to implement the 'SymbolTable' trait");
-
-  SymbolTable targetSymbolTable(target);
-  SymbolTable otherSymbolTable(*other);
-
-  // Step 1:
-  //
-  // Rename private symbols in both ops in order to resolve conflicts that can
-  // be resolved that way.
-  LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n");
-  // TODO: Do we *actually* need to test in both directions?
-  for (auto &&[symbolTable, otherSymbolTable] : llvm::zip(
-           SmallVector<SymbolTable *, 2>{&targetSymbolTable, &otherSymbolTable},
-           SmallVector<SymbolTable *, 2>{&otherSymbolTable,
-                                         &targetSymbolTable})) {
-    Operation *symbolTableOp = symbolTable->getOp();
-    for (Operation &op : symbolTableOp->getRegion(0).front()) {
-      auto symbolOp = dyn_cast<SymbolOpInterface>(op);
-      if (!symbolOp)
-        continue;
-      StringAttr name = symbolOp.getNameAttr();
-      LLVM_DEBUG(DBGS() << "  found @" << name.getValue() << "\n");
-
-      // Check if there is a colliding op in the other module.
-      auto collidingOp =
-          cast_or_null<SymbolOpInterface>(otherSymbolTable->lookup(name));
-      if (!collidingOp)
-        continue;
-
-      LLVM_DEBUG(DBGS() << "    collision found for @" << name.getValue());
-
-      // Collisions are fine if both opt are functions and can be merged.
-      if (auto funcOp = dyn_cast<FunctionOpInterface>(op),
-          collidingFuncOp =
-              dyn_cast<FunctionOpInterface>(collidingOp.getOperation());
-          funcOp && collidingFuncOp) {
-        if (canMergeInto(funcOp, collidingFuncOp) ||
-            canMergeInto(collidingFuncOp, funcOp)) {
-          LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and "
-                                     "will be merged\n");
-          continue;
-        }
-
-        // If they can't be merged, proceed like any other collision.
-        LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions");
-      }
-
-      // Collision can be resolved by renaming if one of the ops is private.
-      auto renameToUnique =
-          [&](SymbolOpInterface op, SymbolOpInterface otherOp,
-              SymbolTable &symbolTable,
-              SymbolTable &otherSymbolTable) -> LogicalResult {
-        LLVM_DEBUG(llvm::dbgs() << ", renaming\n");
-        FailureOr<StringAttr> maybeNewName =
-            symbolTable.renameToUnique(op, {&otherSymbolTable});
-        if (failed(maybeNewName)) {
-          InFlightDiagnostic diag = op->emitError("failed to rename symbol");
-          diag.attachNote(otherOp->getLoc())
-              << "attempted renaming due to collision with this op";
-          return diag;
-        }
-        LLVM_DEBUG(DBGS() << "      renamed to @" << maybeNewName->getValue()
-                          << "\n");
-        return success();
-      };
-
-      if (symbolOp.isPrivate()) {
-        if (failed(renameToUnique(symbolOp, collidingOp, *symbolTable,
-                                  *otherSymbolTable)))
-          return failure();
-        continue;
-      }
-      if (collidingOp.isPrivate()) {
-        if (failed(renameToUnique(collidingOp, symbolOp, *otherSymbolTable,
-                                  *symbolTable)))
-          return failure();
-        continue;
-      }
-
-      LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");
-      InFlightDiagnostic diag = symbolOp.emitError()
-                                << "doubly defined symbol @" << name.getValue();
-      diag.attachNote(collidingOp->getLoc()) << "previously defined here";
-      return diag;
-    }
-  }
-
-  // TODO: This duplicates pass infrastructure. We should split this pass into
-  //       several and let the pass infrastructure do the verification.
-  for (auto *op : SmallVector<Operation *>{target, *other}) {
-    if (failed(mlir::verify(op)))
-      return op->emitError() << "failed to verify input op after renaming";
-  }
-
-  // Step 2:
-  //
-  // Move all ops from `other` into target and merge public symbols.
-  LLVM_DEBUG(DBGS() << "moving all symbols into target\n");
-  {
-    SmallVector<SymbolOpInterface> opsToMove;
-    for (Operation &op : other->getRegion(0).front()) {
-      if (auto symbol = dyn_cast<SymbolOpInterface>(op))
-        opsToMove.push_back(symbol);
-    }
-
-    for (SymbolOpInterface op : opsToMove) {
-      // Remember potentially colliding op in the target module.
-      auto collidingOp = cast_or_null<SymbolOpInterface>(
-          targetSymbolTable.lookup(op.getNameAttr()));
-
-      // Move op even if we get a collision.
-      LLVM_DEBUG(DBGS() << "  moving @" << op.getName());
-      op->moveBefore(&target->getRegion(0).front(),
-                     target->getRegion(0).front().end());
-
-      // If there is no collision, we are done.
-      if (!collidingOp) {
-        LLVM_DEBUG(llvm::dbgs() << " without collision\n");
-        continue;
-      }
-
-      // The two colliding ops must both be functions because we have already
-      // emitted errors otherwise earlier.
-      auto funcOp = cast<FunctionOpInterface>(op.getOperation());
-      auto collidingFuncOp =
-          cast<FunctionOpInterface>(collidingOp.getOperation());
-
-      // Both ops are in the target module now and can be treated symmetrically,
-      // so w.l.o.g. we can reduce to merging `funcOp` into `collidingFuncOp`.
-      if (!canMergeInto(funcOp, collidingFuncOp)) {
-        std::swap(funcOp, collidingFuncOp);
-      }
-      assert(canMergeInto(funcOp, collidingFuncOp));
-
-      LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at "
-                              << collidingFuncOp.getLoc() << ":\n"
-                              << collidingFuncOp << "\n");
-
-      // Update symbol table. This works with or without the previous `swap`.
-      targetSymbolTable.remove(funcOp);
-      targetSymbolTable.insert(collidingFuncOp);
-      assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);
-
-      // Do the actual merging.
-      if (failed(mergeInto(funcOp, collidingFuncOp))) {
-        return failure();
-      }
-    }
-  }
-
-  if (failed(mlir::verify(target)))
-    return target->emitError()
-           << "failed to verify target op after merging symbols";
-
-  LLVM_DEBUG(DBGS() << "done merging ops\n");
-  return success();
-}
-
 LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
     Operation *target, StringRef passName,
     const std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
@@ -595,9 +335,9 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
       diag.attachNote(target->getLoc()) << "pass anchor op";
       return diag;
     }
-    if (failed(
-            mergeSymbolsInto(SymbolTable::getNearestSymbolTable(transformRoot),
-                             transformLibraryModule->get()->clone())))
+    if (failed(detail::mergeSymbolsInto(
+            SymbolTable::getNearestSymbolTable(transformRoot),
+            transformLibraryModule->get()->clone())))
       return emitError(transformRoot->getLoc(),
                        "failed to merge library symbols into transform root");
   }
@@ -683,8 +423,8 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
   OwningOpRef<ModuleOp> moduleFromFile;
   {
     auto loc = FileLineColLoc::get(context, transformFileName, 0, 0);
-    if (failed(parseTransformModuleFromFile(context, transformFileName,
-                                            moduleFromFile)))
+    if (failed(detail::parseTransformModuleFromFile(context, transformFileName,
+                                                    moduleFromFile)))
       return emitError(loc) << "failed to parse transform module";
     if (moduleFromFile && failed(mlir::verify(*moduleFromFile)))
       return emitError(loc) << "failed to verify transform module";
@@ -701,8 +441,8 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
   for (const std::string &libraryFileName : libraryFileNames) {
     OwningOpRef<ModuleOp> parsedLibrary;
     auto loc = FileLineColLoc::get(context, libraryFileName, 0, 0);
-    if (failed(parseTransformModuleFromFile(context, libraryFileName,
-                                            parsedLibrary)))
+    if (failed(detail::parseTransformModuleFromFile(context, libraryFileName,
+                                                    parsedLibrary)))
       return emitError(loc) << "failed to parse transform library module";
     if (parsedLibrary && failed(mlir::verify(*parsedLibrary)))
       return emitError(loc) << "failed to verify transform library module";
@@ -741,8 +481,8 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
     IRRewriter rewriter(context);
     // TODO: extend `mergeSymbolsInto` to support multiple `other` modules.
     for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) {
-      if (failed(mergeSymbolsInto(mergedParsedLibraries.get(),
-                                  std::move(parsedLibrary))))
+      if (failed(detail::mergeSymbolsInto(mergedParsedLibraries.get(),
+                                          std::move(parsedLibrary))))
         return mergedParsedLibraries->emitError()
                << "failed to verify merged transform module";
     }
@@ -751,8 +491,8 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
   // Use parsed libaries to resolve symbols in shared transform module or return
   // as separate library module.
   if (sharedTransformModule && *sharedTransformModule) {
-    if (failed(mergeSymbolsInto(sharedTransformModule->get(),
-                                std::move(mergedParsedLibraries))))
+    if (failed(detail::mergeSymbolsInto(sharedTransformModule->get(),
+                                        std::move(mergedParsedLibraries))))
       return (*sharedTransformModule)->emitError()
              << "failed to merge symbols from library files "
                 "into shared transform module";
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
new file mode 100644
index 000000000000000..1a6ebdd16232e8a
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
@@ -0,0 +1,337 @@
+//===- TransformInterpreterUtils.cpp --------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Lightweight transform dialect interpreter utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Support/FileUtilities.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+#define DEBUG_TYPE "transform-dialect-interpreter-utils"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+/// Parses the content of the file `transformFileName` into a `ModuleOp` that
+/// is stored in `transformModule`.
+///
+/// - An empty file name is not an error: the function returns success and
+///   leaves `transformModule` untouched.
+/// - If the file cannot be opened, an error is emitted at a file location
+///   naming `transformFileName` and failure is returned.
+/// - If parsing does not produce a module, failure is returned.
+/// - Otherwise the result is that of verifying the parsed module.
+LogicalResult transform::detail::parseTransformModuleFromFile(
+    MLIRContext *context, llvm::StringRef transformFileName,
+    OwningOpRef<ModuleOp> &transformModule) {
+  if (transformFileName.empty()) {
+    LLVM_DEBUG(
+        DBGS() << "no transform file name specified, assuming the transform "
+                  "module is embedded in the IR next to the top-level\n");
+    return success();
+  }
+  // Parse transformFileName content into a ModuleOp.
+  std::string errorMessage;
+  auto memoryBuffer = mlir::openInputFile(transformFileName, &errorMessage);
+  if (!memoryBuffer) {
+    return emitError(FileLineColLoc::get(
+               StringAttr::get(context, transformFileName), 0, 0))
+           << "failed to open transform file: " << errorMessage;
+  }
+  // Tell sourceMgr about this buffer, the parser will pick it up.
+  llvm::SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
+  transformModule =
+      OwningOpRef<ModuleOp>(parseSourceFile<ModuleOp>(sourceMgr, context));
+  // A failed parse leaves `transformModule` null; bail out here rather than
+  // dereferencing a null module in the verifier call below.
+  if (!transformModule)
+    return failure();
+  return mlir::verify(*transformModule);
+}
+
+/// Returns the first library module registered with the transform dialect
+/// obtained from `context`, or a null `ModuleOp` when the dialect reports no
+/// library modules.
+ModuleOp transform::detail::getPreloadedTransformModule(MLIRContext *context) {
+  auto *transformDialect =
+      context->getOrLoadDialect<transform::TransformDialect>();
+  auto libraryModules = transformDialect->getLibraryModules();
+  if (libraryModules.empty())
+    return ModuleOp();
+  return *libraryModules.begin();
+}
+
+/// Looks for a `NamedSequenceOp` whose symbol name equals `entryPoint`.
+///
+/// `root` is searched first and then, when it is non-null, `module`. Each is
+/// walked in pre-order and the first match is returned. When neither contains
+/// a match, an error is emitted on `root` and a null interface is returned.
+transform::TransformOpInterface
+transform::detail::findTransformEntryPoint(Operation *root, ModuleOp module,
+                                           StringRef entryPoint) {
+  SmallVector<Operation *, 2> searchRoots{root};
+  if (module)
+    searchRoots.push_back(module);
+  for (Operation *searchRoot : searchRoots) {
+    transform::TransformOpInterface found = nullptr;
+    searchRoot->walk<WalkOrder::PreOrder>(
+        [&](transform::NamedSequenceOp candidate) {
+          // Keep walking until the requested symbol name shows up.
+          if (candidate.getSymName() != entryPoint)
+            return WalkResult::advance();
+          found =
+              cast<transform::TransformOpInterface>(candidate.getOperation());
+          return WalkResult::interrupt();
+        });
+    if (found)
+      return found;
+  }
+  // The diagnostic is reported when the temporary goes out of scope.
+  root->emitError() << "could not find a nested named sequence with name: "
+                    << entryPoint;
+  return nullptr;
+}
+
+/// Return whether `func1` can be merged into `func2`. This requires `func1` to
+/// be a declaration (i.e., external). In addition, `func2` must either be a
+/// declaration too or be public (otherwise, it wouldn't be visible by
+/// `func1`).
+static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
+  // Only a declaration can be merged into another function.
+  if (!func1.isExternal())
+    return false;
+  return func2.isPublic() || func2.isExternal();
+}
+
+/// Merge `func1` into `func2`. Both ops must live in the same parent op and
+/// be mergeable according to `canMergeInto`. On success `func1` has been
+/// erased and only `func2` remains. Fails with an error emitted on `func1`
+/// when the function types differ or when the argument consumption
+/// annotations are incompatible.
+static LogicalResult mergeInto(FunctionOpInterface func1,
+                               FunctionOpInterface func2) {
+  assert(canMergeInto(func1, func2));
+  assert(func1->getParentOp() == func2->getParentOp() &&
+         "expected func1 and func2 to be in the same parent op");
+
+  // Both functions must agree on their signature.
+  if (func1.getFunctionType() != func2.getFunctionType()) {
+    return func1.emitError()
+           << "external definition has a mismatching signature ("
+           << func2.getFunctionType() << ")";
+  }
+
+  // Check and merge the per-argument consumed/readonly annotations.
+  MLIRContext *ctx = func1->getContext();
+  auto *transformDialect = ctx->getLoadedDialect<transform::TransformDialect>();
+  StringAttr consumedAttr = transformDialect->getConsumedAttrName();
+  StringAttr readOnlyAttr = transformDialect->getReadOnlyAttrName();
+  const unsigned numArgs = func1.getNumArguments();
+  for (unsigned idx = 0; idx < numArgs; ++idx) {
+    bool targetIsConsumed = func2.getArgAttr(idx, consumedAttr) != nullptr;
+    bool targetIsReadonly = func2.getArgAttr(idx, readOnlyAttr) != nullptr;
+    bool sourceIsConsumed = func1.getArgAttr(idx, consumedAttr) != nullptr;
+    bool sourceIsReadonly = func1.getArgAttr(idx, readOnlyAttr) != nullptr;
+
+    // `func2` carries neither annotation: adopt the one `func1` has, if any.
+    if (!targetIsConsumed && !targetIsReadonly) {
+      if (sourceIsConsumed)
+        func2.setArgAttr(idx, consumedAttr, UnitAttr::get(ctx));
+      else if (sourceIsReadonly)
+        func2.setArgAttr(idx, readOnlyAttr, UnitAttr::get(ctx));
+      continue;
+    }
+
+    // Otherwise, each annotation present on `func2` must also be on `func1`.
+    bool consumedMismatch = targetIsConsumed && !sourceIsConsumed;
+    bool readonlyMismatch = targetIsReadonly && !sourceIsReadonly;
+    if (consumedMismatch || readonlyMismatch) {
+      return func1.emitError()
+             << "external definition has mismatching consumption "
+                "annotations for argument #"
+             << idx;
+    }
+  }
+
+  // `func1` is the external one, so we can remove it.
+  assert(func1.isExternal());
+  func1->erase();
+
+  return success();
+}
+
+/// Merges all symbols of `other` into `target`. Both ops must have the
+/// `SymbolTable` trait; ownership of `other` is taken by this function.
+///
+/// The merge proceeds in two steps:
+///  1. Name collisions between the two symbol tables are resolved. Colliding
+///     functions that can be merged (see `canMergeInto`) are left alone;
+///     otherwise a colliding private symbol is renamed to a unique name. If
+///     neither colliding symbol is private, an error is emitted and failure
+///     is returned. Both ops are verified after this step.
+///  2. Every symbol op of `other` is moved to the end of the first block of
+///     `target`. If it collides with a symbol already in `target`, the two
+///     functions are merged via `mergeInto`.
+/// Finally, `target` is verified; failure is returned if verification fails.
+LogicalResult
+transform::detail::mergeSymbolsInto(Operation *target,
+                                    OwningOpRef<Operation *> other) {
+  assert(target->hasTrait<OpTrait::SymbolTable>() &&
+         "requires target to implement the 'SymbolTable' trait");
+  assert(other->hasTrait<OpTrait::SymbolTable>() &&
+         "requires other to implement the 'SymbolTable' trait");
+
+  SymbolTable targetSymbolTable(target);
+  SymbolTable otherSymbolTable(*other);
+
+  // Step 1:
+  //
+  // Rename private symbols in both ops in order to resolve conflicts that can
+  // be resolved that way.
+  LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n");
+  // TODO: Do we *actually* need to test in both directions?
+  for (auto &&[symbolTable, otherSymbolTable] : llvm::zip(
+           SmallVector<SymbolTable *, 2>{&targetSymbolTable, &otherSymbolTable},
+           SmallVector<SymbolTable *, 2>{&otherSymbolTable,
+                                         &targetSymbolTable})) {
+    Operation *symbolTableOp = symbolTable->getOp();
+    for (Operation &op : symbolTableOp->getRegion(0).front()) {
+      auto symbolOp = dyn_cast<SymbolOpInterface>(op);
+      if (!symbolOp)
+        continue;
+      StringAttr name = symbolOp.getNameAttr();
+      LLVM_DEBUG(DBGS() << "  found @" << name.getValue() << "\n");
+
+      // Check if there is a colliding op in the other module.
+      auto collidingOp =
+          cast_or_null<SymbolOpInterface>(otherSymbolTable->lookup(name));
+      if (!collidingOp)
+        continue;
+
+      LLVM_DEBUG(DBGS() << "    collision found for @" << name.getValue());
+
+      // Collisions are fine if both ops are functions and can be merged.
+      if (auto funcOp = dyn_cast<FunctionOpInterface>(op),
+          collidingFuncOp =
+              dyn_cast<FunctionOpInterface>(collidingOp.getOperation());
+          funcOp && collidingFuncOp) {
+        if (canMergeInto(funcOp, collidingFuncOp) ||
+            canMergeInto(collidingFuncOp, funcOp)) {
+          LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and "
+                                     "will be merged\n");
+          continue;
+        }
+
+        // If they can't be merged, proceed like any other collision.
+        LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions");
+      }
+
+      // Collision can be resolved by renaming if one of the ops is private.
+      auto renameToUnique =
+          [&](SymbolOpInterface op, SymbolOpInterface otherOp,
+              SymbolTable &symbolTable,
+              SymbolTable &otherSymbolTable) -> LogicalResult {
+        LLVM_DEBUG(llvm::dbgs() << ", renaming\n");
+        FailureOr<StringAttr> maybeNewName =
+            symbolTable.renameToUnique(op, {&otherSymbolTable});
+        if (failed(maybeNewName)) {
+          InFlightDiagnostic diag = op->emitError("failed to rename symbol");
+          diag.attachNote(otherOp->getLoc())
+              << "attempted renaming due to collision with this op";
+          return diag;
+        }
+        LLVM_DEBUG(DBGS() << "      renamed to @" << maybeNewName->getValue()
+                          << "\n");
+        return success();
+      };
+
+      if (symbolOp.isPrivate()) {
+        if (failed(renameToUnique(symbolOp, collidingOp, *symbolTable,
+                                  *otherSymbolTable)))
+          return failure();
+        continue;
+      }
+      if (collidingOp.isPrivate()) {
+        if (failed(renameToUnique(collidingOp, symbolOp, *otherSymbolTable,
+                                  *symbolTable)))
+          return failure();
+        continue;
+      }
+      LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");
+      InFlightDiagnostic diag = symbolOp.emitError()
+                                << "doubly defined symbol @" << name.getValue();
+      diag.attachNote(collidingOp->getLoc()) << "previously defined here";
+      return diag;
+    }
+  }
+
+  // TODO: This duplicates pass infrastructure. We should split this pass into
+  //       several and let the pass infrastructure do the verification.
+  for (auto *op : SmallVector<Operation *>{target, *other}) {
+    if (failed(mlir::verify(op)))
+      return op->emitError() << "failed to verify input op after renaming";
+  }
+
+  // Step 2:
+  //
+  // Move all ops from `other` into target and merge public symbols.
+  LLVM_DEBUG(DBGS() << "moving all symbols into target\n");
+  {
+    // Collect the symbol ops first so that moving them does not interfere with
+    // the iteration over the block of `other`.
+    SmallVector<SymbolOpInterface> opsToMove;
+    for (Operation &op : other->getRegion(0).front()) {
+      if (auto symbol = dyn_cast<SymbolOpInterface>(op))
+        opsToMove.push_back(symbol);
+    }
+
+    for (SymbolOpInterface op : opsToMove) {
+      // Remember potentially colliding op in the target module.
+      auto collidingOp = cast_or_null<SymbolOpInterface>(
+          targetSymbolTable.lookup(op.getNameAttr()));
+
+      // Move op even if we get a collision.
+      LLVM_DEBUG(DBGS() << "  moving @" << op.getName());
+      op->moveBefore(&target->getRegion(0).front(),
+                     target->getRegion(0).front().end());
+
+      // If there is no collision, we are done.
+      if (!collidingOp) {
+        LLVM_DEBUG(llvm::dbgs() << " without collision\n");
+        continue;
+      }
+
+      // The two colliding ops must both be functions because we have already
+      // emitted errors otherwise earlier.
+      auto funcOp = cast<FunctionOpInterface>(op.getOperation());
+      auto collidingFuncOp =
+          cast<FunctionOpInterface>(collidingOp.getOperation());
+
+      // Both ops are in the target module now and can be treated symmetrically,
+      // so w.l.o.g. we can reduce to merging `funcOp` into `collidingFuncOp`.
+      if (!canMergeInto(funcOp, collidingFuncOp)) {
+        std::swap(funcOp, collidingFuncOp);
+      }
+      assert(canMergeInto(funcOp, collidingFuncOp));
+
+      LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at "
+                              << collidingFuncOp.getLoc() << ":\n"
+                              << collidingFuncOp << "\n");
+
+      // Update symbol table. This works with or without the previous `swap`.
+      targetSymbolTable.remove(funcOp);
+      targetSymbolTable.insert(collidingFuncOp);
+      assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);
+
+      // Do the actual merging.
+      if (failed(mergeInto(funcOp, collidingFuncOp))) {
+        return failure();
+      }
+    }
+  }
+
+  if (failed(mlir::verify(target)))
+    return target->emitError()
+           << "failed to verify target op after merging symbols";
+
+  LLVM_DEBUG(DBGS() << "done merging ops\n");
+  return success();
+}
+
+/// Locates the named sequence called `entryPoint` (searching `payload` and
+/// then `transformModule`, see `detail::findTransformEntryPoint`) and applies
+/// it to `payload` with the given `options`.
+///
+/// `transformModule` may be null. When it is non-null and does not contain
+/// the entry point, the symbols of a clone of it are merged into the nearest
+/// symbol table of the entry point before the transform is applied.
+/// Returns failure if the entry point is not found, if the merge fails, or if
+/// applying the transform fails.
+LogicalResult transform::applyTransformNamedSequence(
+    Operation *payload, ModuleOp transformModule,
+    const TransformOptions &options, StringRef entryPoint) {
+  Operation *transformRoot =
+      detail::findTransformEntryPoint(payload, transformModule, entryPoint);
+  if (!transformRoot)
+    return failure();
+
+  // `transformModule` may not be modified, so a clone of it is merged instead.
+  // The clone is created only after checking that `transformModule` is
+  // non-null and that a merge is actually required; cloning unconditionally
+  // would dereference a null module.
+  if (transformModule && !transformModule->isAncestor(transformRoot)) {
+    OwningOpRef<Operation *> clonedTransformModule(transformModule->clone());
+    if (failed(detail::mergeSymbolsInto(
+            SymbolTable::getNearestSymbolTable(transformRoot),
+            std::move(clonedTransformModule))))
+      return failure();
+  }
+
+  // Apply the transform to the IR, do not enforce top-level constraints.
+  RaggedArray<MappedValue> noExtraMappings;
+  return applyTransforms(payload, cast<TransformOpInterface>(transformRoot),
+                         noExtraMappings, options,
+                         /*enforceToplevelTransformOp=*/false);
+}
diff --git a/mlir/unittests/Dialect/Transform/CMakeLists.txt b/mlir/unittests/Dialect/Transform/CMakeLists.txt
index 1fecd21221c91c8..89238a0bdae16ea 100644
--- a/mlir/unittests/Dialect/Transform/CMakeLists.txt
+++ b/mlir/unittests/Dialect/Transform/CMakeLists.txt
@@ -1,8 +1,11 @@
 add_mlir_unittest(MLIRTransformDialectTests
   BuildOnlyExtensionTest.cpp
+  Preload.cpp
 )
 target_link_libraries(MLIRTransformDialectTests
   PRIVATE
   MLIRFuncDialect
+  MLIRTestTransformDialect
   MLIRTransformDialect
+  MLIRTransformDialectTransforms
 )
diff --git a/mlir/unittests/Dialect/Transform/Preload.cpp b/mlir/unittests/Dialect/Transform/Preload.cpp
new file mode 100644
index 000000000000000..d3c3044e0e0f776
--- /dev/null
+++ b/mlir/unittests/Dialect/Transform/Preload.cpp
@@ -0,0 +1,92 @@
+//===- Preload.cpp - Test preloaded transform library modules -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/TypeID.h"
+#include "mlir/Tools/mlir-opt/MlirOptMain.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/raw_ostream.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+// Forward declarations of entities that are defined outside of this file.
+namespace mlir {
+namespace test {
+std::unique_ptr<Pass> createTestTransformDialectInterpreterPass();
+} // namespace test
+} // namespace mlir
+namespace test {
+// Called by the test below to add the test transform dialect extension to a
+// `DialectRegistry`.
+void registerTestTransformDialectExtension(DialectRegistry &registry);
+} // namespace test
+
+// Textual IR of the transform library module used by the test: it defines the
+// private named sequence `@__transform_main`, whose body prints a remark at
+// its operand and yields.
+const static llvm::StringLiteral library = R"MLIR(
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence private @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "from external symbol" : !transform.any_op
+    transform.yield
+  }
+})MLIR";
+
+// Textual IR of the input module used by the test: it only declares (without
+// a body) the private named sequence `@__transform_main` and includes it from
+// a `transform.sequence`.
+const static llvm::StringLiteral input = R"MLIR(
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence private @__transform_main(%arg0: !transform.any_op {transform.readonly})
+
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    include @__transform_main failures(propagate) (%arg0) : (!transform.any_op) -> ()
+  }
+})MLIR";
+
+// Registers a parsed transform library with the transform dialect, retrieves
+// it back through `getPreloadedTransformModule`, and checks that the entry
+// point can be found, that the library symbols can be merged into the input
+// module, and that the named sequence can be applied.
+//
+// Checks guarding a handle that is dereferenced afterwards use `ASSERT_TRUE`
+// so that a failure aborts the test instead of dereferencing a null handle.
+TEST(Preload, ContextPreloadConstructedLibrary) {
+  registerPassManagerCLOptions();
+
+  MLIRContext context;
+  auto *dialect = context.getOrLoadDialect<transform::TransformDialect>();
+  DialectRegistry registry;
+  ::test::registerTestTransformDialectExtension(registry);
+  registry.applyExtensions(&context);
+  ParserConfig parserConfig(&context);
+
+  OwningOpRef<ModuleOp> inputModule =
+      parseSourceString<ModuleOp>(input, parserConfig, "<input>");
+  ASSERT_TRUE(inputModule) << "failed to parse input module";
+
+  OwningOpRef<ModuleOp> transformLibrary =
+      parseSourceString<ModuleOp>(library, parserConfig, "<transform-library>");
+  ASSERT_TRUE(transformLibrary) << "failed to parse transform module";
+  dialect->registerLibraryModule(std::move(transformLibrary));
+
+  ModuleOp retrievedTransformLibrary =
+      transform::detail::getPreloadedTransformModule(&context);
+  ASSERT_TRUE(retrievedTransformLibrary)
+      << "failed to retrieve transform module";
+
+  transform::TransformOpInterface entryPoint =
+      transform::detail::findTransformEntryPoint(inputModule->getOperation(),
+                                                 retrievedTransformLibrary);
+  ASSERT_TRUE(entryPoint) << "failed to find entry point";
+
+  // Merge a clone so that the registered library module itself is not
+  // consumed by `mergeSymbolsInto`.
+  OwningOpRef<Operation *> clonedTransformModule(
+      retrievedTransformLibrary->clone());
+  LogicalResult res = transform::detail::mergeSymbolsInto(
+      inputModule->getOperation(), std::move(clonedTransformModule));
+  EXPECT_TRUE(succeeded(res)) << "failed to define declared symbols";
+
+  transform::TransformOptions options;
+  res = transform::applyTransformNamedSequence(
+      inputModule->getOperation(), retrievedTransformLibrary, options);
+  EXPECT_TRUE(succeeded(res)) << "failed to apply named sequence";
+}



More information about the Mlir-commits mailing list