[Mlir-commits] [mlir] 7876899 - [mlir][transform] Fix handling of transitive include in interpreter. (#67560)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 6 01:57:02 PDT 2023


Author: Ingo Müller
Date: 2023-10-06T10:56:57+02:00
New Revision: 787689943d027b062274f22097e7f30b0a52bd5b

URL: https://github.com/llvm/llvm-project/commit/787689943d027b062274f22097e7f30b0a52bd5b
DIFF: https://github.com/llvm/llvm-project/commit/787689943d027b062274f22097e7f30b0a52bd5b.diff

LOG: [mlir][transform] Fix handling of transitive include in interpreter. (#67560)

Until now, the interpreter would only load those symbols from the
provided library files that were declared in the main transform module.
However, sequences in the library may include other sequences on their
own. Until now, if such sequences were not *also* declared in the main
transform module, the interpreter would fail to resolve them. Forward
declaring all of them is undesirable as it defeats the purpose of
encapsulation into library modules.

This PR implements a kind of linker for transform scripts to solve this
problem. The linker merges all symbols of the library module into the
main module before interpreting the latter. Symbols whose names collide
are handled as follows: (1) if they are both functions (in the sense of
`FunctionOpInterface`) with compatible signatures, one is external, and
the other one is public, then they are merged; (2) if one of them is
private, that one is renamed; and (3) an error is raised otherwise.

One consequence of this change is that the loading of the library files
in the interpreter pass is not idempotent anymore, i.e., subsequent
interpreter passes cannot (and need not) load the same library files again
since doing so would lead to doubly defined symbols.

Added: 
    mlir/test/Dialect/Transform/test-interpreter-external-symbol-def-invalid.mlir

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
    mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
    mlir/include/mlir/IR/SymbolTable.h
    mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
    mlir/lib/IR/SymbolTable.cpp
    mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
    mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
    mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
    mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir
    mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
index 70a76ab9670f907..f28205a25507025 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
@@ -43,6 +43,15 @@ def Transform_Dialect : Dialect {
     constexpr const static ::llvm::StringLiteral kArgReadOnlyAttrName =
         "transform.readonly";
 
+    /// Names of the attributes indicating whether an argument of an external
+    /// transform dialect symbol is consumed or only read.
+    StringAttr getConsumedAttrName() const {
+      return StringAttr::get(getContext(), kArgConsumedAttrName);
+    }
+    StringAttr getReadOnlyAttrName() const {
+      return StringAttr::get(getContext(), kArgReadOnlyAttrName);
+    }
+
     template <typename DataTy>
     const DataTy &getExtraData() const {
       return *static_cast<const DataTy *>(

diff  --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
index 6102417ceda1a7b..a6f0dddebd7eacf 100644
--- a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
@@ -62,9 +62,11 @@ LogicalResult interpreterBaseRunOnOperationImpl(
 ///     transform script. If empty, `debugTransformRootTag` is considered or the
 ///     pass root operation must contain a single top-level transform op that
 ///     will be interpreted.
-///   - transformLibraryFileName: if non-empty, the name of the file containing
-///     definitions of external symbols referenced in the transform script.
-///     These definitions will be used to replace declarations.
+///   - transformLibraryFileName: if non-empty, the module in this file will be
+///     merged into the main transform script run by the interpreter before
+///     execution. This allows providing definitions for external functions
+///     used in the main script. Other public symbols in the library module may
+///     lead to collisions with public symbols in the main script.
 ///   - debugPayloadRootTag: if non-empty, the value of the attribute named
 ///     `kTransformDialectTagAttrName` indicating the single op that is
 ///     considered the payload root of the transform interpreter; otherwise, the
@@ -85,7 +87,7 @@ LogicalResult interpreterBaseRunOnOperationImpl(
 /// as template arguments. They are *not* expected to to implement `initialize`
 /// or `runOnOperation`. They *are* expected to call the copy constructor of
 /// this class in their copy constructors, short of which the file-based
-/// transform dialect script injection facility will become nonoperational.
+/// transform dialect script injection facility will become non-operational.
 ///
 /// Concrete passes may implement the `runBeforeInterpreter` and
 /// `runAfterInterpreter` to customize the behavior of the pass.

diff  --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h
index 33427788a075ed3..7f21f22eba951e3 100644
--- a/mlir/include/mlir/IR/SymbolTable.h
+++ b/mlir/include/mlir/IR/SymbolTable.h
@@ -55,6 +55,23 @@ class SymbolTable {
   /// after insertion as attribute.
   StringAttr insert(Operation *symbol, Block::iterator insertPt = {});
 
+  /// Renames the given op or the op referred to by the given name to the given
+  /// new name and updates the symbol table and all usages of the symbol
+  /// accordingly. Fails if the updating of the usages fails.
+  LogicalResult rename(StringAttr from, StringAttr to);
+  LogicalResult rename(Operation *op, StringAttr to);
+  LogicalResult rename(StringAttr from, StringRef to);
+  LogicalResult rename(Operation *op, StringRef to);
+
+  /// Renames the given op or the op referred to by the given name to a name
+  /// that is unique within this and the provided other symbol tables and
+  /// updates the symbol table and all usages of the symbol accordingly. Returns
+  /// the new name or failure if the renaming fails.
+  FailureOr<StringAttr> renameToUnique(StringAttr from,
+                                       ArrayRef<SymbolTable *> others);
+  FailureOr<StringAttr> renameToUnique(Operation *op,
+                                       ArrayRef<SymbolTable *> others);
+
   /// Return the name of the attribute used for symbol names.
   static StringRef getSymbolAttrName() { return "sym_name"; }
 

diff  --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index 23640c92457a89d..68a735e7ef8e0b3 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -161,17 +161,9 @@ static llvm::raw_ostream &
 printReproCall(llvm::raw_ostream &os, StringRef rootOpName, StringRef passName,
                const Pass::Option<std::string> &debugPayloadRootTag,
                const Pass::Option<std::string> &debugTransformRootTag,
-               const Pass::Option<std::string> &transformLibraryFileName,
                StringRef binaryName) {
-  std::string transformLibraryOption = "";
-  if (!transformLibraryFileName.empty()) {
-    transformLibraryOption =
-        llvm::formatv(" {0}={1}", transformLibraryFileName.getArgStr(),
-                      transformLibraryFileName.getValue())
-            .str();
-  }
   os << llvm::formatv(
-      "{7} --pass-pipeline=\"{0}({1}{{{2}={3} {4}={5}{6}})\"", rootOpName,
+      "{6} --pass-pipeline=\"{0}({1}{{{2}={3} {4}={5}})\"", rootOpName,
       passName, debugPayloadRootTag.getArgStr(),
       debugPayloadRootTag.empty()
           ? StringRef(kTransformDialectTagPayloadRootValue)
@@ -180,14 +172,15 @@ printReproCall(llvm::raw_ostream &os, StringRef rootOpName, StringRef passName,
       debugTransformRootTag.empty()
           ? StringRef(kTransformDialectTagTransformContainerValue)
           : debugTransformRootTag,
-      transformLibraryOption, binaryName);
+      binaryName);
   return os;
 }
 
 /// Prints the module rooted at `root` to `os` and appends
 /// `transformContainer` if it is not nested in `root`.
-llvm::raw_ostream &printModuleForRepro(llvm::raw_ostream &os, Operation *root,
-                                       Operation *transform) {
+static llvm::raw_ostream &printModuleForRepro(llvm::raw_ostream &os,
+                                              Operation *root,
+                                              Operation *transform) {
   root->print(os);
   if (!root->isAncestor(transform))
     transform->print(os);
@@ -196,12 +189,13 @@ llvm::raw_ostream &printModuleForRepro(llvm::raw_ostream &os, Operation *root,
 
 /// Saves the payload and the transform IR into a temporary file and reports
 /// the file name to `os`.
-void saveReproToTempFile(
-    llvm::raw_ostream &os, Operation *target, Operation *transform,
-    StringRef passName, const Pass::Option<std::string> &debugPayloadRootTag,
-    const Pass::Option<std::string> &debugTransformRootTag,
-    const Pass::Option<std::string> &transformLibraryFileName,
-    StringRef binaryName) {
+static void
+saveReproToTempFile(llvm::raw_ostream &os, Operation *target,
+                    Operation *transform, StringRef passName,
+                    const Pass::Option<std::string> &debugPayloadRootTag,
+                    const Pass::Option<std::string> &debugTransformRootTag,
+                    const Pass::Option<std::string> &transformLibraryFileName,
+                    StringRef binaryName) {
   using llvm::sys::fs::TempFile;
   Operation *root = getRootOperation(target);
 
@@ -226,8 +220,7 @@ void saveReproToTempFile(
 
   os << "=== Transform Interpreter Repro ===\n";
   printReproCall(os, root->getName().getStringRef(), passName,
-                 debugPayloadRootTag, debugTransformRootTag,
-                 transformLibraryFileName, binaryName)
+                 debugPayloadRootTag, debugTransformRootTag, binaryName)
       << " " << filename << "\n";
   os << "===================================\n";
 }
@@ -281,8 +274,7 @@ static void performOptionalDebugActions(
     llvm::dbgs() << "=== Transform Interpreter Repro ===\n";
     printReproCall(llvm::dbgs() << "cat <<EOF | ",
                    root->getName().getStringRef(), passName,
-                   debugPayloadRootTag, debugTransformRootTag,
-                   transformLibraryFileName, binaryName)
+                   debugPayloadRootTag, debugTransformRootTag, binaryName)
         << "\n";
     printModuleForRepro(llvm::dbgs(), root, transform);
     llvm::dbgs() << "\nEOF\n";
@@ -302,77 +294,236 @@ static void performOptionalDebugActions(
     transform->removeAttr(kTransformDialectTagAttrName);
 }
 
-/// Replaces external symbols in `block` with their (non-external) definitions
-/// from the given module.
-static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
-  MLIRContext &ctx = *definitions->getContext();
-  auto consumedName =
-      StringAttr::get(&ctx, transform::TransformDialect::kArgConsumedAttrName);
-  auto readOnlyName =
-      StringAttr::get(&ctx, transform::TransformDialect::kArgReadOnlyAttrName);
-
-  for (Operation &op : llvm::make_early_inc_range(block)) {
-    LLVM_DEBUG(DBGS() << op << "\n");
-    auto symbol = dyn_cast<SymbolOpInterface>(op);
-    if (!symbol)
-      continue;
-    if (symbol->getNumRegions() == 1 && !symbol->getRegion(0).empty())
-      continue;
+/// Return whether `func1` can be merged into `func2`. For that to work `func1`
+/// has to be a declaration (aka has to be external) and `func2` either has to
+/// be a declaration as well, or it has to be public (otherwise, it wouldn't
+/// be visible to `func1`).
+static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
+  return func1.isExternal() && (func2.isPublic() || func2.isExternal());
+}
 
-    LLVM_DEBUG(DBGS() << "looking for definition of symbol "
-                      << symbol.getNameAttr() << ":");
-    SymbolTable symbolTable(definitions);
-    Operation *externalSymbol = symbolTable.lookup(symbol.getNameAttr());
-    if (!externalSymbol || externalSymbol->getNumRegions() != 1 ||
-        externalSymbol->getRegion(0).empty()) {
-      LLVM_DEBUG(llvm::dbgs() << "not found\n");
-      continue;
-    }
+/// Merge `func1` into `func2`. The two ops must be inside the same parent op
+/// and mergeable according to `canMergeInto`. The function erases `func1` such
+/// that only `func2` exists when the function returns.
+static LogicalResult mergeInto(FunctionOpInterface func1,
+                               FunctionOpInterface func2) {
+  assert(canMergeInto(func1, func2));
+  assert(func1->getParentOp() == func2->getParentOp() &&
+         "expected func1 and func2 to be in the same parent op");
+
+  // Check that function signatures match.
+  if (func1.getFunctionType() != func2.getFunctionType()) {
+    return func1.emitError()
+           << "external definition has a mismatching signature ("
+           << func2.getFunctionType() << ")";
+  }
 
-    auto symbolFunc = dyn_cast<FunctionOpInterface>(op);
-    auto externalSymbolFunc = dyn_cast<FunctionOpInterface>(externalSymbol);
-    if (!symbolFunc || !externalSymbolFunc) {
-      LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n");
+  // Check and merge argument attributes.
+  MLIRContext *context = func1->getContext();
+  auto *td = context->getLoadedDialect<transform::TransformDialect>();
+  StringAttr consumedName = td->getConsumedAttrName();
+  StringAttr readOnlyName = td->getReadOnlyAttrName();
+  for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) {
+    bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr;
+    bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr;
+    bool isConsumed = func1.getArgAttr(i, consumedName) != nullptr;
+    bool isReadonly = func1.getArgAttr(i, readOnlyName) != nullptr;
+    if (!isExternalConsumed && !isExternalReadonly) {
+      if (isConsumed)
+        func2.setArgAttr(i, consumedName, UnitAttr::get(context));
+      else if (isReadonly)
+        func2.setArgAttr(i, readOnlyName, UnitAttr::get(context));
       continue;
     }
 
-    LLVM_DEBUG(llvm::dbgs() << "found @" << externalSymbol << "\n");
-    if (symbolFunc.getFunctionType() != externalSymbolFunc.getFunctionType()) {
-      return symbolFunc.emitError()
-             << "external definition has a mismatching signature ("
-             << externalSymbolFunc.getFunctionType() << ")";
+    if ((isExternalConsumed && !isConsumed) ||
+        (isExternalReadonly && !isReadonly)) {
+      return func1.emitError()
+             << "external definition has mismatching consumption "
+                "annotations for argument #"
+             << i;
     }
+  }
+
+  // `func1` is the external one, so we can remove it.
+  assert(func1.isExternal());
+  func1->erase();
 
-    for (unsigned i = 0, e = symbolFunc.getNumArguments(); i < e; ++i) {
-      bool isExternalConsumed =
-          externalSymbolFunc.getArgAttr(i, consumedName) != nullptr;
-      bool isExternalReadonly =
-          externalSymbolFunc.getArgAttr(i, readOnlyName) != nullptr;
-      bool isConsumed = symbolFunc.getArgAttr(i, consumedName) != nullptr;
-      bool isReadonly = symbolFunc.getArgAttr(i, readOnlyName) != nullptr;
-      if (!isExternalConsumed && !isExternalReadonly) {
-        if (isConsumed)
-          externalSymbolFunc.setArgAttr(i, consumedName, UnitAttr::get(&ctx));
-        else if (isReadonly)
-          externalSymbolFunc.setArgAttr(i, readOnlyName, UnitAttr::get(&ctx));
+  return success();
+}
+
+/// Merge all symbols from `other` into `target`. Both ops need to implement the
+/// `SymbolTable` trait. Operations are moved from `other`, i.e., `other` may be
+/// modified by this function and might not verify after the function returns.
+/// Upon merging, private symbols may be renamed in order to avoid collisions in
+/// the result. Public symbols may not collide, with the exception of
+/// instances of `FunctionOpInterface`, where collisions are allowed if at
+/// least one of the two is external, in which case the other op is preserved
+/// (or any one of the two if both are external).
+// TODO: Reconsider cloning individual ops rather than forcing users of the
+//       function to clone (or move) `other` in order to improve efficiency.
+//       This might primarily make sense if we can also prune the symbols that
+//       are merged to a subset (such as those that are actually used).
+static LogicalResult mergeSymbolsInto(Operation *target,
+                                      OwningOpRef<Operation *> other) {
+  assert(target->hasTrait<OpTrait::SymbolTable>() &&
+         "requires target to implement the 'SymbolTable' trait");
+  assert(other->hasTrait<OpTrait::SymbolTable>() &&
+         "requires target to implement the 'SymbolTable' trait");
+
+  SymbolTable targetSymbolTable(target);
+  SymbolTable otherSymbolTable(*other);
+
+  // Step 1:
+  //
+  // Rename private symbols in both ops in order to resolve conflicts that can
+  // be resolved that way.
+  LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n");
+  // TODO: Do we *actually* need to test in both directions?
+  for (auto &&[symbolTable, otherSymbolTable] : llvm::zip(
+           SmallVector<SymbolTable *, 2>{&targetSymbolTable, &otherSymbolTable},
+           SmallVector<SymbolTable *, 2>{&otherSymbolTable,
+                                         &targetSymbolTable})) {
+    Operation *symbolTableOp = symbolTable->getOp();
+    for (Operation &op : symbolTableOp->getRegion(0).front()) {
+      auto symbolOp = dyn_cast<SymbolOpInterface>(op);
+      if (!symbolOp)
         continue;
+      StringAttr name = symbolOp.getNameAttr();
+      LLVM_DEBUG(DBGS() << "  found @" << name.getValue() << "\n");
+
+      // Check if there is a colliding op in the other module.
+      auto collidingOp =
+          cast_or_null<SymbolOpInterface>(otherSymbolTable->lookup(name));
+      if (!collidingOp)
+        continue;
+
+      LLVM_DEBUG(DBGS() << "    collision found for @" << name.getValue());
+
+      // Collisions are fine if both ops are functions and can be merged.
+      if (auto funcOp = dyn_cast<FunctionOpInterface>(op),
+          collidingFuncOp =
+              dyn_cast<FunctionOpInterface>(collidingOp.getOperation());
+          funcOp && collidingFuncOp) {
+        if (canMergeInto(funcOp, collidingFuncOp) ||
+            canMergeInto(collidingFuncOp, funcOp)) {
+          LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and "
+                                     "will be merged\n");
+          continue;
+        }
+
+        // If they can't be merged, proceed like any other collision.
+        LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions");
       }
 
-      if ((isExternalConsumed && !isConsumed) ||
-          (isExternalReadonly && !isReadonly)) {
-        return symbolFunc.emitError()
-               << "external definition has mismatching consumption annotations "
-                  "for argument #"
-               << i;
+      // Collision can be resolved by renaming if one of the ops is private.
+      auto renameToUnique =
+          [&](SymbolOpInterface op, SymbolOpInterface otherOp,
+              SymbolTable &symbolTable,
+              SymbolTable &otherSymbolTable) -> LogicalResult {
+        LLVM_DEBUG(llvm::dbgs() << ", renaming\n");
+        FailureOr<StringAttr> maybeNewName =
+            symbolTable.renameToUnique(op, {&otherSymbolTable});
+        if (failed(maybeNewName)) {
+          InFlightDiagnostic diag = op->emitError("failed to rename symbol");
+          diag.attachNote(otherOp->getLoc())
+              << "attempted renaming due to collision with this op";
+          return diag;
+        }
+        LLVM_DEBUG(DBGS() << "      renamed to @" << maybeNewName->getValue()
+                          << "\n");
+        return success();
+      };
+
+      if (symbolOp.isPrivate()) {
+        if (failed(renameToUnique(symbolOp, collidingOp, *symbolTable,
+                                  *otherSymbolTable)))
+          return failure();
+        continue;
       }
+      if (collidingOp.isPrivate()) {
+        if (failed(renameToUnique(collidingOp, symbolOp, *otherSymbolTable,
+                                  *symbolTable)))
+          return failure();
+        continue;
+      }
+
+      LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");
+      InFlightDiagnostic diag = symbolOp.emitError()
+                                << "doubly defined symbol @" << name.getValue();
+      diag.attachNote(collidingOp->getLoc()) << "previously defined here";
+      return diag;
     }
+  }
+
+  // TODO: This duplicates pass infrastructure. We should split this pass into
+  //       several and let the pass infrastructure do the verification.
+  for (auto *op : SmallVector<Operation *>{target, *other}) {
+    if (failed(mlir::verify(op)))
+      return op->emitError() << "failed to verify input op after renaming";
+  }
+
+  // Step 2:
+  //
+  // Move all ops from `other` into target and merge public symbols.
+  LLVM_DEBUG(DBGS() << "moving all symbols into target\n");
+  {
+    SmallVector<SymbolOpInterface> opsToMove;
+    for (Operation &op : other->getRegion(0).front()) {
+      if (auto symbol = dyn_cast<SymbolOpInterface>(op))
+        opsToMove.push_back(symbol);
+    }
+
+    for (SymbolOpInterface op : opsToMove) {
+      // Remember potentially colliding op in the target module.
+      auto collidingOp = cast_or_null<SymbolOpInterface>(
+          targetSymbolTable.lookup(op.getNameAttr()));
+
+      // Move op even if we get a collision.
+      LLVM_DEBUG(DBGS() << "  moving @" << op.getName());
+      op->moveBefore(&target->getRegion(0).front(),
+                     target->getRegion(0).front().end());
+
+      // If there is no collision, we are done.
+      if (!collidingOp) {
+        LLVM_DEBUG(llvm::dbgs() << " without collision\n");
+        continue;
+      }
+
+      // The two colliding ops must both be functions because we have already
+      // emitted errors otherwise earlier.
+      auto funcOp = cast<FunctionOpInterface>(op.getOperation());
+      auto collidingFuncOp =
+          cast<FunctionOpInterface>(collidingOp.getOperation());
+
+      // Both ops are in the target module now and can be treated symmetrically,
+      // so w.l.o.g. we can reduce to merging `funcOp` into `collidingFuncOp`.
+      if (!canMergeInto(funcOp, collidingFuncOp)) {
+        std::swap(funcOp, collidingFuncOp);
+      }
+      assert(canMergeInto(funcOp, collidingFuncOp));
+
+      LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at "
+                              << collidingFuncOp.getLoc() << ":\n"
+                              << collidingFuncOp << "\n");
 
-    OpBuilder builder(&op);
-    builder.setInsertionPoint(&op);
-    builder.clone(*externalSymbol);
-    symbol->erase();
+      // Update symbol table. This works with or without the previous `swap`.
+      targetSymbolTable.remove(funcOp);
+      targetSymbolTable.insert(collidingFuncOp);
+      assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);
+
+      // Do the actual merging.
+      if (failed(mergeInto(funcOp, collidingFuncOp))) {
+        return failure();
+      }
+    }
   }
 
+  if (failed(mlir::verify(target)))
+    return target->emitError()
+           << "failed to verify target op after merging symbols";
+
+  LLVM_DEBUG(DBGS() << "done merging ops\n");
   return success();
 }
 
@@ -443,8 +594,9 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
       diag.attachNote(target->getLoc()) << "pass anchor op";
       return diag;
     }
-    if (failed(defineDeclaredSymbols(*transformRoot->getBlock(),
-                                     transformLibraryModule->get())))
+    if (failed(
+            mergeSymbolsInto(SymbolTable::getNearestSymbolTable(transformRoot),
+                             transformLibraryModule->get()->clone())))
       return failure();
   }
 
@@ -506,8 +658,8 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
     return success();
 
   if (sharedTransformModule && *sharedTransformModule) {
-    if (failed(defineDeclaredSymbols(*sharedTransformModule->get().getBody(),
-                                     parsedLibraryModule.get())))
+    if (failed(mergeSymbolsInto(sharedTransformModule->get(),
+                                std::move(parsedLibraryModule))))
       return failure();
   } else {
     transformLibraryModule =

diff  --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index 2494cb7086f0d7d..8ff1859e1383fcb 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -218,6 +218,79 @@ StringAttr SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
   return getSymbolName(symbol);
 }
 
+LogicalResult SymbolTable::rename(StringAttr from, StringAttr to) {
+  Operation *op = lookup(from);
+  return rename(op, to);
+}
+
+LogicalResult SymbolTable::rename(Operation *op, StringAttr to) {
+  StringAttr from = getNameIfSymbol(op);
+
+  assert(from && "expected valid 'name' attribute");
+  assert(op->getParentOp() == symbolTableOp &&
+         "expected this operation to be inside of the operation with this "
+         "SymbolTable");
+  assert(lookup(from) == op && "current name does not resolve to op");
+  assert(lookup(to) == nullptr && "new name already exists");
+
+  if (failed(SymbolTable::replaceAllSymbolUses(op, to, getOp())))
+    return failure();
+
+  // Remove op with old name, change name, add with new name. The order is
+  // important here due to how `remove` and `insert` rely on the op name.
+  remove(op);
+  setSymbolName(op, to);
+  insert(op);
+
+  assert(lookup(to) == op && "new name does not resolve to renamed op");
+  assert(lookup(from) == nullptr && "old name still exists");
+
+  return success();
+}
+
+LogicalResult SymbolTable::rename(StringAttr from, StringRef to) {
+  auto toAttr = StringAttr::get(getOp()->getContext(), to);
+  return rename(from, toAttr);
+}
+
+LogicalResult SymbolTable::rename(Operation *op, StringRef to) {
+  auto toAttr = StringAttr::get(getOp()->getContext(), to);
+  return rename(op, toAttr);
+}
+
+FailureOr<StringAttr>
+SymbolTable::renameToUnique(StringAttr oldName,
+                            ArrayRef<SymbolTable *> others) {
+
+  // Determine new name that is unique in all symbol tables.
+  StringAttr newName;
+  {
+    MLIRContext *context = oldName.getContext();
+    SmallString<64> prefix = oldName.getValue();
+    int uniqueId = 0;
+    prefix.push_back('_');
+    while (true) {
+      newName = StringAttr::get(context, prefix + Twine(uniqueId++));
+      auto lookupNewName = [&](SymbolTable *st) { return st->lookup(newName); };
+      if (!lookupNewName(this) && llvm::none_of(others, lookupNewName)) {
+        break;
+      }
+    }
+  }
+
+  // Apply renaming.
+  if (failed(rename(oldName, newName)))
+    return failure();
+  return newName;
+}
+
+FailureOr<StringAttr>
+SymbolTable::renameToUnique(Operation *op, ArrayRef<SymbolTable *> others) {
+  StringAttr from = getNameIfSymbol(op);
+  assert(from && "expected valid 'name' attribute");
+  return renameToUnique(from, others);
+}
+
 /// Returns the name of the given symbol operation.
 StringAttr SymbolTable::getSymbolName(Operation *symbol) {
   StringAttr name = getNameIfSymbol(symbol);

diff  --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
index 3d4cb0776982934..dd8d141e994da0e 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
@@ -11,4 +11,10 @@
 
 // expected-remark @below {{message}}
 // expected-remark @below {{unannotated}}
+// expected-remark @below {{internal colliding (without suffix)}}
+// expected-remark @below {{internal colliding_0}}
+// expected-remark @below {{internal colliding_1}}
+// expected-remark @below {{internal colliding_3}}
+// expected-remark @below {{internal colliding_4}}
+// expected-remark @below {{internal colliding_5}}
 module {}

diff  --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
index b21abbbdfd6d045..7452deb39b6c18d 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
@@ -1,16 +1,16 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def-invalid.mlir}, test-transform-dialect-interpreter)" \
 // RUN:             --verify-diagnostics --split-input-file
 
-// The definition of the @foo named sequence is provided in another file. It
+// The definition of the @print_message named sequence is provided in another file. It
 // will be included because of the pass option.
 
 module attributes {transform.with_named_sequence} {
   // expected-error @below {{external definition has a mismatching signature}}
-  transform.named_sequence private @foo(!transform.op<"builtin.module"> {transform.readonly})
+  transform.named_sequence private @print_message(!transform.op<"builtin.module"> {transform.readonly})
 
   transform.sequence failures(propagate) {
   ^bb0(%arg0: !transform.op<"builtin.module">):
-    include @foo failures(propagate) (%arg0) : (!transform.op<"builtin.module">) -> ()
+    include @print_message failures(propagate) (%arg0) : (!transform.op<"builtin.module">) -> ()
   }
 }
 
@@ -37,3 +37,18 @@ module attributes {transform.with_named_sequence} {
     include @consuming failures(suppress) (%arg0) : (!transform.any_op) -> ()
   }
 }
+
+// -----
+
+module attributes {transform.with_named_sequence} {
+  // expected-error @below {{doubly defined symbol @print_message}}
+  transform.named_sequence @print_message(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "message" : !transform.any_op
+    transform.yield
+  }
+
+  transform.sequence failures(suppress) {
+  ^bb0(%arg0: !transform.any_op):
+    include @print_message failures(propagate) (%arg0) : (!transform.any_op) -> ()
+  }
+}

diff  --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
index 04b6c5a02e0adf1..7d0837abebde32c 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
@@ -4,29 +4,68 @@
 // RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
 // RUN:             --verify-diagnostics --split-input-file | FileCheck %s
 
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
-// RUN:             --verify-diagnostics --split-input-file | FileCheck %s
-
-// The definition of the @foo named sequence is provided in another file. It
-// will be included because of the pass option. Repeated application of the
-// same pass, with or without the library option, should not be a problem.
+// The definition of the @print_message named sequence is provided in another
+// file. It will be included because of the pass option. Subsequent application
+// of the same pass works but only without the library file (since the first
+// application loads external symbols and loading them again would make them
+// clash).
 // Note that the same diagnostic produced twice at the same location only
 // needs to be matched once.
 
 // expected-remark @below {{message}}
 // expected-remark @below {{unannotated}}
+// expected-remark @below {{internal colliding (without suffix)}}
+// expected-remark @below {{internal colliding_0}}
+// expected-remark @below {{internal colliding_1}}
+// expected-remark @below {{internal colliding_3}}
+// expected-remark @below {{internal colliding_4}}
+// expected-remark @below {{internal colliding_5}}
 module attributes {transform.with_named_sequence} {
-  // CHECK: transform.named_sequence @foo
-  // CHECK: test_print_remark_at_operand %{{.*}}, "message"
-  transform.named_sequence private @foo(!transform.any_op {transform.readonly})
+  // CHECK-DAG: transform.named_sequence @print_message(
+  // CHECK-DAG: transform.include @private_helper
+  transform.named_sequence private @print_message(!transform.any_op {transform.readonly})
+
+  // These ops collide with ops from the other module before or after renaming.
+  transform.named_sequence private @colliding(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "internal colliding (without suffix)" : !transform.any_op
+    transform.yield
+  }
+  transform.named_sequence private @colliding_0(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "internal colliding_0" : !transform.any_op
+    transform.yield
+  }
+  transform.named_sequence private @colliding_1(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "internal colliding_1" : !transform.any_op
+    transform.yield
+  }
+  transform.named_sequence private @colliding_3(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "internal colliding_3" : !transform.any_op
+    transform.yield
+  }
+  // This symbol is public and thus can't be renamed.
+  // CHECK-DAG: transform.named_sequence @colliding_4(
+  transform.named_sequence @colliding_4(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "internal colliding_4" : !transform.any_op
+    transform.yield
+  }
+  transform.named_sequence private @colliding_5(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "internal colliding_5" : !transform.any_op
+    transform.yield
+  }
 
-  // CHECK: transform.named_sequence @unannotated
-  // CHECK: test_print_remark_at_operand %{{.*}}, "unannotated"
-  transform.named_sequence private @unannotated(!transform.any_op {transform.readonly})
+  // CHECK-DAG: transform.named_sequence @unannotated(
+  // CHECK-DAG: test_print_remark_at_operand %{{.*}}, "unannotated"
+  transform.named_sequence @unannotated(!transform.any_op {transform.readonly})
 
   transform.sequence failures(propagate) {
   ^bb0(%arg0: !transform.any_op):
-    include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    include @print_message failures(propagate) (%arg0) : (!transform.any_op) -> ()
     include @unannotated failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    include @colliding failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    include @colliding_0 failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    include @colliding_1 failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    include @colliding_3 failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    include @colliding_4 failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    include @colliding_5 failures(propagate) (%arg0) : (!transform.any_op) -> ()
   }
 }

diff  --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def-invalid.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def-invalid.mlir
new file mode 100644
index 000000000000000..1d9ef1dbead63c6
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def-invalid.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s
+
+module attributes {transform.with_named_sequence} {
+  // expected-note @below {{previously defined here}}
+  transform.named_sequence @print_message(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "message" : !transform.any_op
+    transform.yield
+  }
+
+  transform.named_sequence @consuming(%arg0: !transform.any_op {transform.consumed}) {
+    transform.test_consume_operand %arg0 : !transform.any_op
+    transform.yield
+  }
+}
\ No newline at end of file

diff  --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir
index 1149bda98ab8527..66f0f1f62683b7e 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir
@@ -1,11 +1,42 @@
 // RUN: mlir-opt %s
 
 module attributes {transform.with_named_sequence} {
-  transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) {
+  transform.named_sequence private @private_helper(%arg0: !transform.any_op {transform.readonly}) {
     transform.test_print_remark_at_operand %arg0, "message" : !transform.any_op
     transform.yield
   }
 
+  // These ops collide with ops from the other module before or after renaming.
+  transform.named_sequence private @colliding(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "external colliding (without suffix)" : !transform.any_op
+    transform.yield
+  }
+  transform.named_sequence private @colliding_0(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "external colliding_0" : !transform.any_op
+    transform.yield
+  }
+  transform.named_sequence private @colliding_2(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "external colliding_2" : !transform.any_op
+    transform.yield
+  }
+  transform.named_sequence private @colliding_3(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "external colliding_3" : !transform.any_op
+    transform.yield
+  }
+  transform.named_sequence private @colliding_4(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "external colliding_4" : !transform.any_op
+    transform.yield
+  }
+  transform.named_sequence @colliding_5(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "external colliding_5" : !transform.any_op
+    transform.yield
+  }
+
+  transform.named_sequence @print_message(%arg0: !transform.any_op {transform.readonly}) {
+    transform.include @private_helper failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    transform.yield
+  }
+
   transform.named_sequence @consuming(%arg0: !transform.any_op {transform.consumed}) {
     transform.test_consume_operand %arg0 : !transform.any_op
     transform.yield
@@ -15,4 +46,14 @@ module attributes {transform.with_named_sequence} {
     transform.test_print_remark_at_operand %arg0, "unannotated" : !transform.any_op
     transform.yield
   }
+
+  transform.named_sequence @symbol_user(%arg0: !transform.any_op {transform.readonly}) {
+    transform.include @colliding failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    transform.include @colliding_0 failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    transform.include @colliding_2 failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    transform.include @colliding_3 failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    transform.include @colliding_4 failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    transform.include @colliding_5 failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    transform.yield
+  }
 }

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
index f73deef9d5fd48c..578d9abe4a56ecb 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
@@ -218,9 +218,9 @@ class TestTransformDialectInterpreterPass
           "select the container of the top-level transform op.")};
   Option<std::string> transformLibraryFileName{
       *this, "transform-library-file-name", llvm::cl::init(""),
-      llvm::cl::desc(
-          "Optional name of the file containing transform dialect symbol "
-          "definitions to be injected into the transform module.")};
+      llvm::cl::desc("Optional name of a file with a module that should be "
+                     "merged into the transform module to provide the "
+                     "definitions of external named sequences.")};
 
   Option<bool> testModuleGeneration{
       *this, "test-module-generation", llvm::cl::init(false),


        


More information about the Mlir-commits mailing list