[Mlir-commits] [mlir] [mlir][transform] Allow passing various library files to interpreter. (PR #67120)

Ingo Müller llvmlistbot at llvm.org
Fri Sep 29 07:11:48 PDT 2023


https://github.com/ingomueller-net updated https://github.com/llvm/llvm-project/pull/67120

>From 680516754397efcbbfc9fc1d75dad91846cd45ff Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Fri, 22 Sep 2023 15:10:30 +0000
Subject: [PATCH 1/4] [mlir][transform] Fix handling of transitive include in
 interpreter.

Until now, the interpreter would only load those symbols from the
provided library files that were declared in the main transform module.
However, sequences in the library may themselves include other
sequences. If such sequences were not *also* declared in the main
transform module, the interpreter would fail to resolve them. Forward
declaring all of them is undesirable as it defeats the purpose of
encapsulation into library modules.

This PR implements a kind of linker for transform scripts to solve this
problem. The linker merges all symbols of the library module into the
main module before interpreting the latter. Symbols whose names collide
are handled as follows: (1) if they are both functions (in the sense of
`FunctionOpInterface`) with compatible signatures, one is external, and
the other one is public (or also external), then they are merged; (2) if
one of them is private, that one is renamed; and (3) an error is raised
otherwise.
---
 .../TransformInterpreterPassBase.cpp          | 283 ++++++++++++++----
 ...ter-external-symbol-decl-and-schedule.mlir |   6 +
 ...erpreter-external-symbol-decl-invalid.mlir |  23 +-
 ...test-interpreter-external-symbol-decl.mlir |  61 +++-
 ...terpreter-external-symbol-def-invalid.mlir |  14 +
 .../test-interpreter-external-symbol-def.mlir |  43 ++-
 6 files changed, 352 insertions(+), 78 deletions(-)
 create mode 100644 mlir/test/Dialect/Transform/test-interpreter-external-symbol-def-invalid.mlir

diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index d5c65b23e3a2134..09b3601d59b3c32 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -302,77 +302,236 @@ static void performOptionalDebugActions(
     transform->removeAttr(kTransformDialectTagAttrName);
 }
 
-/// Replaces external symbols in `block` with their (non-external) definitions
-/// from the given module.
-static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
-  MLIRContext &ctx = *definitions->getContext();
-  auto consumedName =
-      StringAttr::get(&ctx, transform::TransformDialect::kArgConsumedAttrName);
-  auto readOnlyName =
-      StringAttr::get(&ctx, transform::TransformDialect::kArgReadOnlyAttrName);
-
-  for (Operation &op : llvm::make_early_inc_range(block)) {
-    LLVM_DEBUG(DBGS() << op << "\n");
-    auto symbol = dyn_cast<SymbolOpInterface>(op);
-    if (!symbol)
-      continue;
-    if (symbol->getNumRegions() == 1 && !symbol->getRegion(0).empty())
-      continue;
-
-    LLVM_DEBUG(DBGS() << "looking for definition of symbol "
-                      << symbol.getNameAttr() << ":");
-    SymbolTable symbolTable(definitions);
-    Operation *externalSymbol = symbolTable.lookup(symbol.getNameAttr());
-    if (!externalSymbol || externalSymbol->getNumRegions() != 1 ||
-        externalSymbol->getRegion(0).empty()) {
-      LLVM_DEBUG(llvm::dbgs() << "not found\n");
-      continue;
-    }
+/// Merge all symbols from `other` into `target`. Both ops need to implement the
+/// `SymbolTable` trait. Operations are moved from `other`, i.e., `other` may be
+/// modified by this function. Upon merging, private symbols may be renamed in
+/// order to avoid collisions in the result. Public symbols may not collide,
+/// with the exception of `SymbolInterfaceOp`s, where collisions are allowed if
+/// at least one of the two is external, in which case the other op preserved
+/// (or one of the two if both are external). The `target` op might not verify
+/// after this function returns.
+// XXX: Make `other` argument an `OwningOpRef`?
+static LogicalResult mergeSymbolsInto(Operation *target, Operation *other) {
+  assert(target->hasTrait<OpTrait::SymbolTable>() &&
+         "requires target to implement the 'SymbolTable' trait");
+  assert(other->hasTrait<OpTrait::SymbolTable>() &&
+         "requires target to implement the 'SymbolTable' trait");
+
+  MLIRContext *context = other->getContext();
+  auto consumedName = StringAttr::get(
+      context, transform::TransformDialect::kArgConsumedAttrName);
+  auto readOnlyName = StringAttr::get(
+      context, transform::TransformDialect::kArgReadOnlyAttrName);
+
+  int uniqueId = 0;
+
+  auto canBeMerged = [](FunctionOpInterface func1, FunctionOpInterface func2) {
+    return func1.isExternal() && (func2.isPublic() || func2.isExternal());
+    ;
+  };
+
+  // Rename private symbols in both ops in order to resolve conflicts that can
+  // be resolved that way.
+  LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n");
+  for (auto [symbolTableOp, otherSymbolTableOp] :
+       llvm::zip(SmallVector<Operation *>{target, other},
+                 SmallVector<Operation *>{other, target})) {
+    SymbolTable symbolTable(symbolTableOp); // XXX: build only once
+    SymbolTable otherSymbolTable(otherSymbolTableOp);
+    for (Operation &op : symbolTableOp->getRegion(0).front()) {
+      auto symbolOp = dyn_cast<SymbolOpInterface>(op);
+      if (!symbolOp)
+        continue;
+      StringAttr name = symbolOp.getNameAttr();
+      LLVM_DEBUG(DBGS() << "  found @" << name.getValue() << "\n");
+
+      // Check if there is a colliding op in the other module.
+      auto collidingOp =
+          cast_or_null<SymbolOpInterface>(otherSymbolTable.lookup(name));
+      if (!collidingOp)
+        continue;
+
+      LLVM_DEBUG(DBGS() << "    collision found for @" << name.getValue());
+
+      // Collisions are fine if both opt are functions and can be merged.
+      if (auto funcOp = dyn_cast<FunctionOpInterface>(op),
+          collidingFuncOp =
+              dyn_cast<FunctionOpInterface>(collidingOp.getOperation());
+          funcOp && collidingFuncOp) {
+        if (canBeMerged(funcOp, collidingFuncOp) ||
+            canBeMerged(collidingFuncOp, funcOp)) {
+          LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and "
+                                     "will be merged\n");
+          continue;
+        }
+
+        // If they can't be merged, proceed like any other collision.
+        LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions");
+      }
+
+      /// Rename `op` inside `symbolTableOp` with symbol table `symbolTable`
+      /// to avoid a collision with `otherOp`.
+      auto renameToUnique =
+          [&uniqueId =
+               uniqueId](SymbolOpInterface op, SymbolOpInterface otherOp,
+                         Operation *symbolTableOp, SymbolTable &symbolTable,
+                         SymbolTable &otherSymbolTable) -> LogicalResult {
+        assert(SymbolTable::getNearestSymbolTable(op) == symbolTableOp &&
+               "expected 'op' to be inside of 'symbolTableOp'");
+        MLIRContext *context = op->getContext();
+
+        // Determine new name that is unique in both symbol tables.
+        StringAttr newName;
+        {
+          SmallString<64> prefix = op.getNameAttr().getValue();
+          prefix.push_back('_');
+          while (true) {
+            newName = StringAttr::get(context, prefix + Twine(uniqueId++));
+            if (!symbolTable.lookup(newName) &&
+                !otherSymbolTable.lookup(newName)) {
+              break;
+            }
+          }
+        }
+
+        // Apply renaming.
+        LLVM_DEBUG(llvm::dbgs()
+                   << ", renaming to @" << newName.getValue() << "\n");
+        if (failed(SymbolTable::replaceAllSymbolUses(op, newName,
+                                                     symbolTableOp))) {
+          InFlightDiagnostic diag =
+              emitError(op->getLoc(), Twine("failed to rename symbol to @") +
+                                          newName.getValue());
+          diag.attachNote(otherOp->getLoc())
+              << "renaming due to collision with this op";
+          return diag;
+        }
+        op.setName(newName); // XXX: Why is this necessary? Why does
+                             // SymbolTable::renameAllSymbolUses not do it?
+        return success();
+      };
+
+      // Collision can be resolved if one of the ops is private.
+      if (symbolOp.isPrivate()) {
+        if (failed(renameToUnique(symbolOp, collidingOp, symbolTableOp,
+                                  symbolTable, otherSymbolTable)))
+          return failure();
+        continue;
+      }
+      if (collidingOp.isPrivate()) {
+        if (failed(renameToUnique(collidingOp, symbolOp, otherSymbolTableOp,
+                                  symbolTable, otherSymbolTable)))
+          return failure();
+        continue;
+      }
 
-    auto symbolFunc = dyn_cast<FunctionOpInterface>(op);
-    auto externalSymbolFunc = dyn_cast<FunctionOpInterface>(externalSymbol);
-    if (!symbolFunc || !externalSymbolFunc) {
-      LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n");
-      continue;
+      LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");
+      InFlightDiagnostic diag =
+          emitError(symbolOp->getLoc(),
+                    Twine("doubly defined symbol @") + name.getValue());
+      diag.attachNote(collidingOp->getLoc()) << "previously defined here";
+      return diag;
     }
+  }
+
+  for (auto *op : SmallVector<Operation *>{target, other}) {
+    if (failed(mlir::verify(op)))
+      return emitError(op->getLoc(),
+                       "failed to verify input op after renaming");
+  }
 
-    LLVM_DEBUG(llvm::dbgs() << "found @" << externalSymbol << "\n");
-    if (symbolFunc.getFunctionType() != externalSymbolFunc.getFunctionType()) {
-      return symbolFunc.emitError()
-             << "external definition has a mismatching signature ("
-             << externalSymbolFunc.getFunctionType() << ")";
+  LLVM_DEBUG(DBGS() << "moving all symbols into target\n");
+  {
+    SmallVector<SymbolOpInterface> opsToMove;
+    for (Operation &op : other->getRegion(0).front()) {
+      if (auto symbol = dyn_cast<SymbolOpInterface>(op))
+        opsToMove.push_back(symbol);
     }
 
-    for (unsigned i = 0, e = symbolFunc.getNumArguments(); i < e; ++i) {
-      bool isExternalConsumed =
-          externalSymbolFunc.getArgAttr(i, consumedName) != nullptr;
-      bool isExternalReadonly =
-          externalSymbolFunc.getArgAttr(i, readOnlyName) != nullptr;
-      bool isConsumed = symbolFunc.getArgAttr(i, consumedName) != nullptr;
-      bool isReadonly = symbolFunc.getArgAttr(i, readOnlyName) != nullptr;
-      if (!isExternalConsumed && !isExternalReadonly) {
-        if (isConsumed)
-          externalSymbolFunc.setArgAttr(i, consumedName, UnitAttr::get(&ctx));
-        else if (isReadonly)
-          externalSymbolFunc.setArgAttr(i, readOnlyName, UnitAttr::get(&ctx));
+    SymbolTable symbolTable(target);
+    for (SymbolOpInterface op : opsToMove) {
+      // Remember potentially colliding op in the target module.
+      auto collidingOp =
+          cast_or_null<SymbolOpInterface>(symbolTable.lookup(op.getNameAttr()));
+
+      // Move op even if we get a collision.
+      LLVM_DEBUG(DBGS() << "  moving @" << op.getName());
+      op->moveAfter(&target->getRegion(0).front(),
+                    target->getRegion(0).front().begin());
+
+      // If there is no collision, we are done.
+      if (!collidingOp) {
+        LLVM_DEBUG(llvm::dbgs() << " without collision\n");
         continue;
       }
 
-      if ((isExternalConsumed && !isConsumed) ||
-          (isExternalReadonly && !isReadonly)) {
+      // We now have a collision that we resolve through merging. The merging
+      // may bring the symbol table out of date but we don't need to access the
+      // table for that symbol anymore.
+
+      // The two colliding ops must bot be functions because we have already
+      // emitted errors otherwise earlier.
+      auto symbolFunc = cast<FunctionOpInterface>(op.getOperation());
+      auto externalSymbolFunc =
+          cast<FunctionOpInterface>(collidingOp.getOperation());
+
+      // Both ops are in the target module now and can be treated symmetrically,
+      // so w.l.o.g. we can reduce to merging `funcOp` into `collidingFuncOp`.
+      if (!canBeMerged(symbolFunc, externalSymbolFunc))
+        std::swap(symbolFunc, externalSymbolFunc);
+      assert(canBeMerged(symbolFunc, externalSymbolFunc));
+
+      LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at "
+                              << externalSymbolFunc.getLoc() << ":\n"
+                              << externalSymbolFunc << "\n");
+
+      // Check that function signatures match.
+      // XXX: Do that check earlier?
+      if (symbolFunc.getFunctionType() !=
+          externalSymbolFunc.getFunctionType()) {
         return symbolFunc.emitError()
-               << "external definition has mismatching consumption annotations "
-                  "for argument #"
-               << i;
+               << "external definition has a mismatching signature ("
+               << externalSymbolFunc.getFunctionType() << ")";
       }
-    }
 
-    OpBuilder builder(&op);
-    builder.setInsertionPoint(&op);
-    builder.clone(*externalSymbol);
-    symbol->erase();
+      // Check and merge argument attributes.
+      for (unsigned i = 0, e = symbolFunc.getNumArguments(); i < e; ++i) {
+        bool isExternalConsumed =
+            externalSymbolFunc.getArgAttr(i, consumedName) != nullptr;
+        bool isExternalReadonly =
+            externalSymbolFunc.getArgAttr(i, readOnlyName) != nullptr;
+        bool isConsumed = symbolFunc.getArgAttr(i, consumedName) != nullptr;
+        bool isReadonly = symbolFunc.getArgAttr(i, readOnlyName) != nullptr;
+        if (!isExternalConsumed && !isExternalReadonly) {
+          if (isConsumed)
+            externalSymbolFunc.setArgAttr(i, consumedName,
+                                          UnitAttr::get(context));
+          else if (isReadonly)
+            externalSymbolFunc.setArgAttr(i, readOnlyName,
+                                          UnitAttr::get(context));
+          continue;
+        }
+
+        if ((isExternalConsumed && !isConsumed) ||
+            (isExternalReadonly && !isReadonly)) {
+          return symbolFunc.emitError()
+                 << "external definition has mismatching consumption "
+                    "annotations for argument #"
+                 << i;
+        }
+      }
+
+      // `funcOp` is the external one, so we can remove it.
+      assert(symbolFunc.isExternal());
+      symbolFunc->erase();
+    }
   }
 
+  if (failed(mlir::verify(target)))
+    return emitError(target->getLoc(),
+                     "failed to verify target op after merging symbols");
+
+  LLVM_DEBUG(DBGS() << "done merging ops\n");
   return success();
 }
 
@@ -438,8 +597,9 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
       diag.attachNote(target->getLoc()) << "pass anchor op";
       return diag;
     }
-    if (failed(defineDeclaredSymbols(*transformRoot->getBlock(),
-                                     libraryModule->get())))
+    if (failed(
+            mergeSymbolsInto(SymbolTable::getNearestSymbolTable(transformRoot),
+                             libraryModule->get())))
       return failure();
   }
 
@@ -499,8 +659,7 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
     return success();
 
   if (module && *module) {
-    if (failed(defineDeclaredSymbols(*module->get().getBody(),
-                                     parsedLibrary.get())))
+    if (failed(mergeSymbolsInto(module->get(), parsedLibrary.get())))
       return failure();
   } else {
     libraryModule =
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
index 3d4cb0776982934..dd8d141e994da0e 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
@@ -11,4 +11,10 @@
 
 // expected-remark @below {{message}}
 // expected-remark @below {{unannotated}}
+// expected-remark @below {{internal colliding (without suffix)}}
+// expected-remark @below {{internal colliding_0}}
+// expected-remark @below {{internal colliding_1}}
+// expected-remark @below {{internal colliding_3}}
+// expected-remark @below {{internal colliding_4}}
+// expected-remark @below {{internal colliding_5}}
 module {}
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
index b21abbbdfd6d045..7452deb39b6c18d 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
@@ -1,16 +1,16 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def-invalid.mlir}, test-transform-dialect-interpreter)" \
 // RUN:             --verify-diagnostics --split-input-file
 
-// The definition of the @foo named sequence is provided in another file. It
+// The definition of the @print_message named sequence is provided in another file. It
 // will be included because of the pass option.
 
 module attributes {transform.with_named_sequence} {
   // expected-error @below {{external definition has a mismatching signature}}
-  transform.named_sequence private @foo(!transform.op<"builtin.module"> {transform.readonly})
+  transform.named_sequence private @print_message(!transform.op<"builtin.module"> {transform.readonly})
 
   transform.sequence failures(propagate) {
   ^bb0(%arg0: !transform.op<"builtin.module">):
-    include @foo failures(propagate) (%arg0) : (!transform.op<"builtin.module">) -> ()
+    include @print_message failures(propagate) (%arg0) : (!transform.op<"builtin.module">) -> ()
   }
 }
 
@@ -37,3 +37,18 @@ module attributes {transform.with_named_sequence} {
     include @consuming failures(suppress) (%arg0) : (!transform.any_op) -> ()
   }
 }
+
+// -----
+
+module attributes {transform.with_named_sequence} {
+  // expected-error @below {{doubly defined symbol @print_message}}
+  transform.named_sequence @print_message(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "message" : !transform.any_op
+    transform.yield
+  }
+
+  transform.sequence failures(suppress) {
+  ^bb0(%arg0: !transform.any_op):
+    include @print_message failures(propagate) (%arg0) : (!transform.any_op) -> ()
+  }
+}
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
index 04b6c5a02e0adf1..d14c55e6b7be882 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
@@ -4,29 +4,68 @@
 // RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
 // RUN:             --verify-diagnostics --split-input-file | FileCheck %s
 
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
-// RUN:             --verify-diagnostics --split-input-file | FileCheck %s
+// XXX: This currently fails.
+// RoooooN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
+// RoooooN:             --verify-diagnostics --split-input-file | FileCheck %s
 
-// The definition of the @foo named sequence is provided in another file. It
-// will be included because of the pass option. Repeated application of the
-// same pass, with or without the library option, should not be a problem.
+// The definition of the @print_message named sequence is provided in another
+// file. It will be included because of the pass option. Repeated application of
+// the same pass, with or without the library option, should not be a problem.
 // Note that the same diagnostic produced twice at the same location only
 // needs to be matched once.
 
 // expected-remark @below {{message}}
 // expected-remark @below {{unannotated}}
+// expected-remark @below {{internal colliding (without suffix)}}
+// expected-remark @below {{internal colliding_0}}
+// expected-remark @below {{internal colliding_1}}
+// expected-remark @below {{internal colliding_3}}
+// expected-remark @below {{internal colliding_4}}
+// expected-remark @below {{internal colliding_5}}
 module attributes {transform.with_named_sequence} {
-  // CHECK: transform.named_sequence @foo
-  // CHECK: test_print_remark_at_operand %{{.*}}, "message"
-  transform.named_sequence private @foo(!transform.any_op {transform.readonly})
+  // CHECK: transform.named_sequence @print_message(
+  // CHECK: transform.include @private_helper
+  transform.named_sequence private @print_message(!transform.any_op {transform.readonly})
+
+  // These ops collide with ops from the other module before or after renaming.
+  transform.named_sequence private @colliding(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "internal colliding (without suffix)" : !transform.any_op
+    transform.yield
+  }
+  transform.named_sequence private @colliding_0(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "internal colliding_0" : !transform.any_op
+    transform.yield
+  }
+  transform.named_sequence private @colliding_1(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "internal colliding_1" : !transform.any_op
+    transform.yield
+  }
+  transform.named_sequence private @colliding_3(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "internal colliding_3" : !transform.any_op
+    transform.yield
+  }
+  transform.named_sequence @colliding_4(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "internal colliding_4" : !transform.any_op
+    transform.yield
+  }
+  transform.named_sequence private @colliding_5(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "internal colliding_5" : !transform.any_op
+    transform.yield
+  }
 
-  // CHECK: transform.named_sequence @unannotated
+  // CHECK: transform.named_sequence @unannotated(
   // CHECK: test_print_remark_at_operand %{{.*}}, "unannotated"
-  transform.named_sequence private @unannotated(!transform.any_op {transform.readonly})
+  transform.named_sequence @unannotated(!transform.any_op {transform.readonly})
 
   transform.sequence failures(propagate) {
   ^bb0(%arg0: !transform.any_op):
-    include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    include @print_message failures(propagate) (%arg0) : (!transform.any_op) -> ()
     include @unannotated failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    include @colliding failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    include @colliding_0 failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    include @colliding_1 failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    include @colliding_3 failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    include @colliding_4 failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    include @colliding_5 failures(propagate) (%arg0) : (!transform.any_op) -> ()
   }
 }
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def-invalid.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def-invalid.mlir
new file mode 100644
index 000000000000000..1d9ef1dbead63c6
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def-invalid.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s
+
+module attributes {transform.with_named_sequence} {
+  // expected-note @below {{previously defined here}}
+  transform.named_sequence @print_message(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "message" : !transform.any_op
+    transform.yield
+  }
+
+  transform.named_sequence @consuming(%arg0: !transform.any_op {transform.consumed}) {
+    transform.test_consume_operand %arg0 : !transform.any_op
+    transform.yield
+  }
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir
index 1149bda98ab8527..66f0f1f62683b7e 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir
@@ -1,11 +1,42 @@
 // RUN: mlir-opt %s
 
 module attributes {transform.with_named_sequence} {
-  transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) {
+  transform.named_sequence private @private_helper(%arg0: !transform.any_op {transform.readonly}) {
     transform.test_print_remark_at_operand %arg0, "message" : !transform.any_op
     transform.yield
   }
 
+  // These ops collide with ops from the other module before or after renaming.
+  transform.named_sequence private @colliding(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "external colliding (without suffix)" : !transform.any_op
+    transform.yield
+  }
+  transform.named_sequence private @colliding_0(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "external colliding_0" : !transform.any_op
+    transform.yield
+  }
+  transform.named_sequence private @colliding_2(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "external colliding_2" : !transform.any_op
+    transform.yield
+  }
+  transform.named_sequence private @colliding_3(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "external colliding_3" : !transform.any_op
+    transform.yield
+  }
+  transform.named_sequence private @colliding_4(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "external colliding_4" : !transform.any_op
+    transform.yield
+  }
+  transform.named_sequence @colliding_5(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "external colliding_5" : !transform.any_op
+    transform.yield
+  }
+
+  transform.named_sequence @print_message(%arg0: !transform.any_op {transform.readonly}) {
+    transform.include @private_helper failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    transform.yield
+  }
+
   transform.named_sequence @consuming(%arg0: !transform.any_op {transform.consumed}) {
     transform.test_consume_operand %arg0 : !transform.any_op
     transform.yield
@@ -15,4 +46,14 @@ module attributes {transform.with_named_sequence} {
     transform.test_print_remark_at_operand %arg0, "unannotated" : !transform.any_op
     transform.yield
   }
+
+  transform.named_sequence @symbol_user(%arg0: !transform.any_op {transform.readonly}) {
+    transform.include @colliding failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    transform.include @colliding_0 failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    transform.include @colliding_2 failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    transform.include @colliding_3 failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    transform.include @colliding_4 failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    transform.include @colliding_5 failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    transform.yield
+  }
 }

>From 5daa3f7981915c6eae52ce83ff409e52da1d3bb3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Fri, 29 Sep 2023 07:37:51 +0000
Subject: [PATCH 2/4] Several minor and major clean-ups:

* Move all private functions of the C++ file into an anonymous namespace.
* Remove the test that runs a second interpreter pass which reloads the
  library; I think that this shouldn't be possible.
* Factor out `renameToUnique`, `canMergeInto`, and `mergeInto` into
  proper functions.
* Use a single symbol table per input op and update it correctly
  whenever symbols or ops change.
* Make the `other` argument an `OwningOpRef` and clone the arguments
  where necessary.
* Improve comments.
---
 .../TransformInterpreterPassBase.cpp          | 285 ++++++++++--------
 ...test-interpreter-external-symbol-decl.mlir |  10 +-
 2 files changed, 163 insertions(+), 132 deletions(-)

diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index 09b3601d59b3c32..3dd172335160fb3 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -50,6 +50,8 @@ constexpr static llvm::StringLiteral kTransformDialectTagPayloadRootValue =
 constexpr static llvm::StringLiteral
     kTransformDialectTagTransformContainerValue = "transform_container";
 
+namespace {
+
 /// Utility to parse the content of a `transformFileName` MLIR file containing
 /// a transform dialect specification.
 static LogicalResult
@@ -302,42 +304,142 @@ static void performOptionalDebugActions(
     transform->removeAttr(kTransformDialectTagAttrName);
 }
 
+/// Rename `op` to avoid a collision with `otherOp`. `symbolTable` and
+/// `otherSymbolTable` are the symbol tables of the two ops, respectively.
+/// `uniqueId` is used to generate a unique name in the context of the caller.
+LogicalResult renameToUnique(SymbolOpInterface op, SymbolOpInterface otherOp,
+                             SymbolTable &symbolTable,
+                             SymbolTable &otherSymbolTable, int &uniqueId) {
+  assert(symbolTable.lookup(op.getNameAttr()) == op &&
+         "symbol table does not contain op");
+  assert(otherSymbolTable.lookup(otherOp.getNameAttr()) == otherOp &&
+         "other symbol table does not contain other op");
+
+  // Determine new name that is unique in both symbol tables.
+  StringAttr oldName = op.getNameAttr();
+  StringAttr newName;
+  {
+    MLIRContext *context = op->getContext();
+    SmallString<64> prefix = oldName.getValue();
+    prefix.push_back('_');
+    while (true) {
+      newName = StringAttr::get(context, prefix + Twine(uniqueId++));
+      if (!symbolTable.lookup(newName) && !otherSymbolTable.lookup(newName)) {
+        break;
+      }
+    }
+  }
+
+  // Apply renaming.
+  LLVM_DEBUG(llvm::dbgs() << ", renaming to @" << newName.getValue() << "\n");
+  Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(op);
+  if (failed(SymbolTable::replaceAllSymbolUses(op, newName, symbolTableOp))) {
+    InFlightDiagnostic diag =
+        emitError(op->getLoc(),
+                  Twine("failed to rename symbol to @") + newName.getValue());
+    diag.attachNote(otherOp->getLoc())
+        << "attempted renaming due to collision with this op";
+    return diag;
+  }
+
+  // Change the symbol in the op itself and update the symbol table.
+  symbolTable.remove(op);
+  SymbolTable::setSymbolName(op, newName);
+  symbolTable.insert(op);
+
+  assert(symbolTable.lookup(newName) == op &&
+         "symbol table does not resolve to renamed op");
+  assert(symbolTable.lookup(oldName) == nullptr &&
+         "symbol table still resolves old name");
+
+  return success();
+}
+
+/// Return whether `func1` can be merged into `func2`.
+bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
+  return func1.isExternal() && (func2.isPublic() || func2.isExternal());
+}
+
+/// Merge `func1` into `func2`. The two ops must be inside the same parent op
+/// and mergeable according to `canMergeInto`. The function erases `func1` such
+/// that only `func2` exists when the function returns.
+LogicalResult mergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
+  assert(canMergeInto(func1, func2));
+  assert(func1->getParentOp() == func2->getParentOp() &&
+         "expected func1 and func2 to be in the same parent op");
+
+  MLIRContext *context = func1->getContext();
+  auto consumedName = StringAttr::get(
+      context, transform::TransformDialect::kArgConsumedAttrName);
+  auto readOnlyName = StringAttr::get(
+      context, transform::TransformDialect::kArgReadOnlyAttrName);
+
+  // Check that function signatures match.
+  if (func1.getFunctionType() != func2.getFunctionType()) {
+    return func1.emitError()
+           << "external definition has a mismatching signature ("
+           << func2.getFunctionType() << ")";
+  }
+
+  // Check and merge argument attributes.
+  for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) {
+    bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr;
+    bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr;
+    bool isConsumed = func1.getArgAttr(i, consumedName) != nullptr;
+    bool isReadonly = func1.getArgAttr(i, readOnlyName) != nullptr;
+    if (!isExternalConsumed && !isExternalReadonly) {
+      if (isConsumed)
+        func2.setArgAttr(i, consumedName, UnitAttr::get(context));
+      else if (isReadonly)
+        func2.setArgAttr(i, readOnlyName, UnitAttr::get(context));
+      continue;
+    }
+
+    if ((isExternalConsumed && !isConsumed) ||
+        (isExternalReadonly && !isReadonly)) {
+      return func1.emitError()
+             << "external definition has mismatching consumption "
+                "annotations for argument #"
+             << i;
+    }
+  }
+
+  // `func1` is the external one, so we can remove it.
+  assert(func1.isExternal());
+  func1->erase();
+
+  return success();
+}
+
 /// Merge all symbols from `other` into `target`. Both ops need to implement the
 /// `SymbolTable` trait. Operations are moved from `other`, i.e., `other` may be
-/// modified by this function. Upon merging, private symbols may be renamed in
-/// order to avoid collisions in the result. Public symbols may not collide,
-/// with the exception of `SymbolInterfaceOp`s, where collisions are allowed if
-/// at least one of the two is external, in which case the other op preserved
-/// (or one of the two if both are external). The `target` op might not verify
-/// after this function returns.
-// XXX: Make `other` argument an `OwningOpRef`?
-static LogicalResult mergeSymbolsInto(Operation *target, Operation *other) {
+/// modified by this function and might not verify after the function returns.
+/// Upon merging, private symbols may be renamed in order to avoid collisions in
+/// the result. Public symbols may not collide, with the exception of
+/// instances of `FunctionOpInterface`, where collisions are allowed if at
+/// least one of the two is external, in which case the other op is preserved
+/// (or any one of the two if both are external).
+static LogicalResult mergeSymbolsInto(Operation *target,
+                                      OwningOpRef<Operation *> other) {
   assert(target->hasTrait<OpTrait::SymbolTable>() &&
          "requires target to implement the 'SymbolTable' trait");
   assert(other->hasTrait<OpTrait::SymbolTable>() &&
          "requires target to implement the 'SymbolTable' trait");
 
-  MLIRContext *context = other->getContext();
-  auto consumedName = StringAttr::get(
-      context, transform::TransformDialect::kArgConsumedAttrName);
-  auto readOnlyName = StringAttr::get(
-      context, transform::TransformDialect::kArgReadOnlyAttrName);
+  SymbolTable targetSymbolTable(target);
+  SymbolTable otherSymbolTable(*other);
 
   int uniqueId = 0;
 
-  auto canBeMerged = [](FunctionOpInterface func1, FunctionOpInterface func2) {
-    return func1.isExternal() && (func2.isPublic() || func2.isExternal());
-    ;
-  };
-
+  // Step 1:
+  //
   // Rename private symbols in both ops in order to resolve conflicts that can
   // be resolved that way.
   LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n");
-  for (auto [symbolTableOp, otherSymbolTableOp] :
-       llvm::zip(SmallVector<Operation *>{target, other},
-                 SmallVector<Operation *>{other, target})) {
-    SymbolTable symbolTable(symbolTableOp); // XXX: build only once
-    SymbolTable otherSymbolTable(otherSymbolTableOp);
+  for (auto [symbolTable, otherSymbolTable] : llvm::zip(
+           SmallVector<SymbolTable *>{&targetSymbolTable, &otherSymbolTable},
+           SmallVector<SymbolTable *>{&otherSymbolTable, &targetSymbolTable})) {
+    Operation *symbolTableOp = symbolTable->getOp();
     for (Operation &op : symbolTableOp->getRegion(0).front()) {
       auto symbolOp = dyn_cast<SymbolOpInterface>(op);
       if (!symbolOp)
@@ -347,7 +449,7 @@ static LogicalResult mergeSymbolsInto(Operation *target, Operation *other) {
 
       // Check if there is a colliding op in the other module.
       auto collidingOp =
-          cast_or_null<SymbolOpInterface>(otherSymbolTable.lookup(name));
+          cast_or_null<SymbolOpInterface>(otherSymbolTable->lookup(name));
       if (!collidingOp)
         continue;
 
@@ -358,8 +460,8 @@ static LogicalResult mergeSymbolsInto(Operation *target, Operation *other) {
           collidingFuncOp =
               dyn_cast<FunctionOpInterface>(collidingOp.getOperation());
           funcOp && collidingFuncOp) {
-        if (canBeMerged(funcOp, collidingFuncOp) ||
-            canBeMerged(collidingFuncOp, funcOp)) {
+        if (canMergeInto(funcOp, collidingFuncOp) ||
+            canMergeInto(collidingFuncOp, funcOp)) {
           LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and "
                                      "will be merged\n");
           continue;
@@ -369,58 +471,16 @@ static LogicalResult mergeSymbolsInto(Operation *target, Operation *other) {
         LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions");
       }
 
-      /// Rename `op` inside `symbolTableOp` with symbol table `symbolTable`
-      /// to avoid a collision with `otherOp`.
-      auto renameToUnique =
-          [&uniqueId =
-               uniqueId](SymbolOpInterface op, SymbolOpInterface otherOp,
-                         Operation *symbolTableOp, SymbolTable &symbolTable,
-                         SymbolTable &otherSymbolTable) -> LogicalResult {
-        assert(SymbolTable::getNearestSymbolTable(op) == symbolTableOp &&
-               "expected 'op' to be inside of 'symbolTableOp'");
-        MLIRContext *context = op->getContext();
-
-        // Determine new name that is unique in both symbol tables.
-        StringAttr newName;
-        {
-          SmallString<64> prefix = op.getNameAttr().getValue();
-          prefix.push_back('_');
-          while (true) {
-            newName = StringAttr::get(context, prefix + Twine(uniqueId++));
-            if (!symbolTable.lookup(newName) &&
-                !otherSymbolTable.lookup(newName)) {
-              break;
-            }
-          }
-        }
-
-        // Apply renaming.
-        LLVM_DEBUG(llvm::dbgs()
-                   << ", renaming to @" << newName.getValue() << "\n");
-        if (failed(SymbolTable::replaceAllSymbolUses(op, newName,
-                                                     symbolTableOp))) {
-          InFlightDiagnostic diag =
-              emitError(op->getLoc(), Twine("failed to rename symbol to @") +
-                                          newName.getValue());
-          diag.attachNote(otherOp->getLoc())
-              << "renaming due to collision with this op";
-          return diag;
-        }
-        op.setName(newName); // XXX: Why is this necessary? Why does
-                             // SymbolTable::renameAllSymbolUses not do it?
-        return success();
-      };
-
       // Collision can be resolved if one of the ops is private.
       if (symbolOp.isPrivate()) {
-        if (failed(renameToUnique(symbolOp, collidingOp, symbolTableOp,
-                                  symbolTable, otherSymbolTable)))
+        if (failed(renameToUnique(symbolOp, collidingOp, *symbolTable,
+                                  *otherSymbolTable, uniqueId)))
           return failure();
         continue;
       }
       if (collidingOp.isPrivate()) {
-        if (failed(renameToUnique(collidingOp, symbolOp, otherSymbolTableOp,
-                                  symbolTable, otherSymbolTable)))
+        if (failed(renameToUnique(collidingOp, symbolOp, *otherSymbolTable,
+                                  *symbolTable, uniqueId)))
           return failure();
         continue;
       }
@@ -434,12 +494,15 @@ static LogicalResult mergeSymbolsInto(Operation *target, Operation *other) {
     }
   }
 
-  for (auto *op : SmallVector<Operation *>{target, other}) {
+  for (auto *op : SmallVector<Operation *>{target, *other}) {
     if (failed(mlir::verify(op)))
       return emitError(op->getLoc(),
                        "failed to verify input op after renaming");
   }
 
+  // Step 2:
+  //
+  // Move all ops from `other` into target and merge public symbols.
   LLVM_DEBUG(DBGS() << "moving all symbols into target\n");
   {
     SmallVector<SymbolOpInterface> opsToMove;
@@ -448,11 +511,10 @@ static LogicalResult mergeSymbolsInto(Operation *target, Operation *other) {
         opsToMove.push_back(symbol);
     }
 
-    SymbolTable symbolTable(target);
     for (SymbolOpInterface op : opsToMove) {
       // Remember potentially colliding op in the target module.
-      auto collidingOp =
-          cast_or_null<SymbolOpInterface>(symbolTable.lookup(op.getNameAttr()));
+      auto collidingOp = cast_or_null<SymbolOpInterface>(
+          targetSymbolTable.lookup(op.getNameAttr()));
 
       // Move op even if we get a collision.
       LLVM_DEBUG(DBGS() << "  moving @" << op.getName());
@@ -465,65 +527,34 @@ static LogicalResult mergeSymbolsInto(Operation *target, Operation *other) {
         continue;
       }
 
-      // We now have a collision that we resolve through merging. The merging
-      // may bring the symbol table out of date but we don't need to access the
-      // table for that symbol anymore.
-
-      // The two colliding ops must bot be functions because we have already
+      // The two colliding ops must both be functions because we have already
       // emitted errors otherwise earlier.
-      auto symbolFunc = cast<FunctionOpInterface>(op.getOperation());
-      auto externalSymbolFunc =
+      auto funcOp = cast<FunctionOpInterface>(op.getOperation());
+      auto collidingFuncOp =
           cast<FunctionOpInterface>(collidingOp.getOperation());
 
       // Both ops are in the target module now and can be treated symmetrically,
       // so w.l.o.g. we can reduce to merging `funcOp` into `collidingFuncOp`.
-      if (!canBeMerged(symbolFunc, externalSymbolFunc))
-        std::swap(symbolFunc, externalSymbolFunc);
-      assert(canBeMerged(symbolFunc, externalSymbolFunc));
+      if (!canMergeInto(funcOp, collidingFuncOp)) {
+        std::swap(funcOp, collidingFuncOp);
+      }
+      assert(canMergeInto(funcOp, collidingFuncOp));
 
       LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at "
-                              << externalSymbolFunc.getLoc() << ":\n"
-                              << externalSymbolFunc << "\n");
-
-      // Check that function signatures match.
-      // XXX: Do that check earlier?
-      if (symbolFunc.getFunctionType() !=
-          externalSymbolFunc.getFunctionType()) {
-        return symbolFunc.emitError()
-               << "external definition has a mismatching signature ("
-               << externalSymbolFunc.getFunctionType() << ")";
-      }
+                              << collidingFuncOp.getLoc() << ":\n"
+                              << collidingFuncOp << "\n");
 
-      // Check and merge argument attributes.
-      for (unsigned i = 0, e = symbolFunc.getNumArguments(); i < e; ++i) {
-        bool isExternalConsumed =
-            externalSymbolFunc.getArgAttr(i, consumedName) != nullptr;
-        bool isExternalReadonly =
-            externalSymbolFunc.getArgAttr(i, readOnlyName) != nullptr;
-        bool isConsumed = symbolFunc.getArgAttr(i, consumedName) != nullptr;
-        bool isReadonly = symbolFunc.getArgAttr(i, readOnlyName) != nullptr;
-        if (!isExternalConsumed && !isExternalReadonly) {
-          if (isConsumed)
-            externalSymbolFunc.setArgAttr(i, consumedName,
-                                          UnitAttr::get(context));
-          else if (isReadonly)
-            externalSymbolFunc.setArgAttr(i, readOnlyName,
-                                          UnitAttr::get(context));
-          continue;
-        }
+      // Update symbol table. This works with or without the previous `swap`.
+      targetSymbolTable.remove(funcOp);
+      targetSymbolTable.insert(collidingFuncOp);
+      assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);
 
-        if ((isExternalConsumed && !isConsumed) ||
-            (isExternalReadonly && !isReadonly)) {
-          return symbolFunc.emitError()
-                 << "external definition has mismatching consumption "
-                    "annotations for argument #"
-                 << i;
-        }
+      // Do the actual merging.
+      if (failed(mergeInto(funcOp, collidingFuncOp))) {
+        return failure();
       }
 
-      // `funcOp` is the external one, so we can remove it.
-      assert(symbolFunc.isExternal());
-      symbolFunc->erase();
+      assert(succeeded(mlir::verify(target)));
     }
   }
 
@@ -535,6 +566,8 @@ static LogicalResult mergeSymbolsInto(Operation *target, Operation *other) {
   return success();
 }
 
+} // namespace
+
 LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
     Operation *target, StringRef passName,
     const std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
@@ -599,7 +632,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
     }
     if (failed(
             mergeSymbolsInto(SymbolTable::getNearestSymbolTable(transformRoot),
-                             libraryModule->get())))
+                             libraryModule->get()->clone())))
       return failure();
   }
 
@@ -659,7 +692,7 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
     return success();
 
   if (module && *module) {
-    if (failed(mergeSymbolsInto(module->get(), parsedLibrary.get())))
+    if (failed(mergeSymbolsInto(module->get(), std::move(parsedLibrary))))
       return failure();
   } else {
     libraryModule =
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
index d14c55e6b7be882..a9083fe3e70788a 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
@@ -4,13 +4,11 @@
 // RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
 // RUN:             --verify-diagnostics --split-input-file | FileCheck %s
 
-// XXX: This currently fails.
-// RoooooN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
-// RoooooN:             --verify-diagnostics --split-input-file | FileCheck %s
-
 // The definition of the @print_message named sequence is provided in another
-// file. It will be included because of the pass option. Repeated application of
-// the same pass, with or without the library option, should not be a problem.
+// file. It will be included because of the pass option. Subsequent application
+// of the same pass works but only without the library file (since the first
+// application loads external symbols and loading them again would make them
+// clash).
 // Note that the same diagnostic produced twice at the same location only
 // needs to be matched once.
 

>From 02149a5a72c94ac2a61d444136c2f8786898ad49 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Fri, 29 Sep 2023 13:55:20 +0000
Subject: [PATCH 3/4] Minor fixes.

* Use `moveBefore` instead of `moveAfter` in order to work on empty
  targets as well.
* Do not verify the target after moving each op, since the last op may
  use symbols of ops that still have to be moved.
---
 .../Transform/Transforms/TransformInterpreterPassBase.cpp   | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index 3dd172335160fb3..424de11a46503b7 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -518,8 +518,8 @@ static LogicalResult mergeSymbolsInto(Operation *target,
 
       // Move op even if we get a collision.
       LLVM_DEBUG(DBGS() << "  moving @" << op.getName());
-      op->moveAfter(&target->getRegion(0).front(),
-                    target->getRegion(0).front().begin());
+      op->moveBefore(&target->getRegion(0).front(),
+                     target->getRegion(0).front().end());
 
       // If there is no collision, we are done.
       if (!collidingOp) {
@@ -553,8 +553,6 @@ static LogicalResult mergeSymbolsInto(Operation *target,
       if (failed(mergeInto(funcOp, collidingFuncOp))) {
         return failure();
       }
-
-      assert(succeeded(mlir::verify(target)));
     }
   }
 

>From 4315e21bcfc0f4fe5e48f010c6592947275ae332 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Fri, 22 Sep 2023 12:05:38 +0000
Subject: [PATCH 4/4] [mlir][transform] Allow passing various library files to
 interpreter.

The transform interpreter accepts an argument to a "library" file with
named sequences. This patch extends this functionality such that (1)
several such individual files are accepted and (2) folders can be passed
in, in which all `*.mlir` files are loaded.
---
 .../Transforms/TransformInterpreterPassBase.h |  43 +++--
 .../TransformInterpreterPassBase.cpp          | 180 +++++++++++++-----
 mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir |   2 +-
 ...ter-external-symbol-decl-and-schedule.mlir |   4 +-
 ...-interpreter-external-symbol-decl-dir.mlir |  22 +++
 ...erpreter-external-symbol-decl-invalid.mlir |   5 +-
 ...test-interpreter-external-symbol-decl.mlir |   4 +-
 .../definitions-self-contained.mlir}          |   0
 .../definitions-with-unresolved.mlir          |  10 +
 .../Dialect/Transform/match_batch_matmul.mlir |   2 +-
 .../Dialect/Transform/match_matmul.mlir       |   2 +-
 .../TestTransformDialectInterpreter.cpp       |  17 +-
 12 files changed, 220 insertions(+), 71 deletions(-)
 create mode 100644 mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-dir.mlir
 rename mlir/test/Dialect/Transform/{test-interpreter-external-symbol-def.mlir => test-interpreter-library/definitions-self-contained.mlir} (100%)
 create mode 100644 mlir/test/Dialect/Transform/test-interpreter-library/definitions-with-unresolved.mlir

diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
index 91903e254b0d5b3..5d021da90e9a594 100644
--- a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
@@ -33,7 +33,8 @@ namespace detail {
 /// Template-free implementation of TransformInterpreterPassBase::initialize.
 LogicalResult interpreterBaseInitializeImpl(
     MLIRContext *context, StringRef transformFileName,
-    StringRef transformLibraryFileName,
+    ArrayRef<std::string> transformLibraryFileNames,
+    ArrayRef<std::string> transformLibraryDirNames,
     std::shared_ptr<OwningOpRef<ModuleOp>> &module,
     std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
     function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
@@ -48,7 +49,8 @@ LogicalResult interpreterBaseRunOnOperationImpl(
     const RaggedArray<MappedValue> &extraMappings,
     const TransformOptions &options,
     const Pass::Option<std::string> &transformFileName,
-    const Pass::Option<std::string> &transformLibraryFileName,
+    const Pass::ListOption<std::string> &transformLibraryFileNames,
+    const Pass::ListOption<std::string> &transformLibraryDirNames,
     const Pass::Option<std::string> &debugPayloadRootTag,
     const Pass::Option<std::string> &debugTransformRootTag,
     StringRef binaryName);
@@ -62,9 +64,14 @@ LogicalResult interpreterBaseRunOnOperationImpl(
 ///     transform script. If empty, `debugTransformRootTag` is considered or the
 ///     pass root operation must contain a single top-level transform op that
 ///     will be interpreted.
-///   - transformLibraryFileName: if non-empty, the name of the file containing
+///   - transformLibraryFileNames: if non-empty, the names of files containing
 ///     definitions of external symbols referenced in the transform script.
-///     These definitions will be used to replace declarations.
+///     These definitions will be used to replace declarations and must be
+///     unique within all files provided by this and the next option.
+///   - transformLibraryDirNames: if non-empty, the names of directories
+///     containing definitions of external symbols referenced in the transform
+///     script. These definitions will be used to replace declarations and must
+///     be unique within all files provided by this and the previous option.
 ///   - debugPayloadRootTag: if non-empty, the value of the attribute named
 ///     `kTransformDialectTagAttrName` indicating the single op that is
 ///     considered the payload root of the transform interpreter; otherwise, the
@@ -115,17 +122,30 @@ class TransformInterpreterPassBase : public GeneratedBase<Concrete> {
     REQUIRE_PASS_OPTION(transformFileName);
     REQUIRE_PASS_OPTION(debugPayloadRootTag);
     REQUIRE_PASS_OPTION(debugTransformRootTag);
-    REQUIRE_PASS_OPTION(transformLibraryFileName);
 
 #undef REQUIRE_PASS_OPTION
 
+#define REQUIRE_PASS_LIST_OPTION(NAME)                                         \
+  static_assert(                                                               \
+      std::is_same_v<                                                          \
+          std::remove_reference_t<decltype(std::declval<Concrete &>().NAME)>,  \
+          Pass::ListOption<std::string>>,                                      \
+      "required " #NAME " string pass option is missing")
+
+    REQUIRE_PASS_LIST_OPTION(transformLibraryFileNames);
+    REQUIRE_PASS_LIST_OPTION(transformLibraryDirNames);
+
+#undef REQUIRE_PASS_LIST_OPTION
+
     StringRef transformFileName =
         static_cast<Concrete *>(this)->transformFileName;
-    StringRef transformLibraryFileName =
-        static_cast<Concrete *>(this)->transformLibraryFileName;
+    ArrayRef<std::string> transformLibraryFileNames =
+        static_cast<Concrete *>(this)->transformLibraryFileNames;
+    ArrayRef<std::string> transformLibraryDirNames =
+        static_cast<Concrete *>(this)->transformLibraryDirNames;
     return detail::interpreterBaseInitializeImpl(
-        context, transformFileName, transformLibraryFileName,
-        sharedTransformModule, transformLibraryModule,
+        context, transformFileName, transformLibraryFileNames,
+        transformLibraryDirNames, sharedTransformModule, transformLibraryModule,
         [this](OpBuilder &builder, Location loc) {
           return static_cast<Concrete *>(this)->constructTransformModule(
               builder, loc);
@@ -159,8 +179,9 @@ class TransformInterpreterPassBase : public GeneratedBase<Concrete> {
             op, pass->getArgument(), sharedTransformModule,
             transformLibraryModule,
             /*extraMappings=*/{}, options, pass->transformFileName,
-            pass->transformLibraryFileName, pass->debugPayloadRootTag,
-            pass->debugTransformRootTag, binaryName)) ||
+            pass->transformLibraryFileNames, pass->transformLibraryDirNames,
+            pass->debugPayloadRootTag, pass->debugTransformRootTag,
+            binaryName)) ||
         failed(pass->runAfterInterpreter(op))) {
       return pass->signalPassFailure();
     }
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index 424de11a46503b7..df0f25556a832e3 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/IR/Visitors.h"
@@ -163,15 +164,24 @@ static llvm::raw_ostream &
 printReproCall(llvm::raw_ostream &os, StringRef rootOpName, StringRef passName,
                const Pass::Option<std::string> &debugPayloadRootTag,
                const Pass::Option<std::string> &debugTransformRootTag,
-               const Pass::Option<std::string> &transformLibraryFileName,
+               const Pass::ListOption<std::string> &transformLibraryFileNames,
+               const Pass::ListOption<std::string> &transformLibraryDirNames,
                StringRef binaryName) {
-  std::string transformLibraryOption = "";
-  if (!transformLibraryFileName.empty()) {
-    transformLibraryOption =
-        llvm::formatv(" {0}={1}", transformLibraryFileName.getArgStr(),
-                      transformLibraryFileName.getValue())
-            .str();
+  std::string transformLibraryOptions = "";
+  {
+    llvm::raw_string_ostream optionStream(transformLibraryOptions);
+    if (!transformLibraryFileNames.empty()) {
+      optionStream << " " << transformLibraryFileNames.getArgStr() << "='";
+      llvm::interleave(transformLibraryFileNames, optionStream, ",");
+      optionStream << "'";
+    }
+    if (!transformLibraryDirNames.empty()) {
+      optionStream << " " << transformLibraryDirNames.getArgStr() << "='";
+      llvm::interleave(transformLibraryDirNames, optionStream, ",");
+      optionStream << "'";
+    }
   }
+
   os << llvm::formatv(
       "{7} --pass-pipeline=\"{0}({1}{{{2}={3} {4}={5}{6}})\"", rootOpName,
       passName, debugPayloadRootTag.getArgStr(),
@@ -182,7 +192,7 @@ printReproCall(llvm::raw_ostream &os, StringRef rootOpName, StringRef passName,
       debugTransformRootTag.empty()
           ? StringRef(kTransformDialectTagTransformContainerValue)
           : debugTransformRootTag,
-      transformLibraryOption, binaryName);
+      transformLibraryOptions, binaryName);
   return os;
 }
 
@@ -202,7 +212,8 @@ void saveReproToTempFile(
     llvm::raw_ostream &os, Operation *target, Operation *transform,
     StringRef passName, const Pass::Option<std::string> &debugPayloadRootTag,
     const Pass::Option<std::string> &debugTransformRootTag,
-    const Pass::Option<std::string> &transformLibraryFileName,
+    const Pass::ListOption<std::string> &transformLibraryFileNames,
+    const Pass::ListOption<std::string> &transformLibraryDirNames,
     StringRef binaryName) {
   using llvm::sys::fs::TempFile;
   Operation *root = getRootOperation(target);
@@ -229,7 +240,8 @@ void saveReproToTempFile(
   os << "=== Transform Interpreter Repro ===\n";
   printReproCall(os, root->getName().getStringRef(), passName,
                  debugPayloadRootTag, debugTransformRootTag,
-                 transformLibraryFileName, binaryName)
+                 transformLibraryFileNames, transformLibraryDirNames,
+                 binaryName)
       << " " << filename << "\n";
   os << "===================================\n";
 }
@@ -240,7 +252,8 @@ static void performOptionalDebugActions(
     Operation *target, Operation *transform, StringRef passName,
     const Pass::Option<std::string> &debugPayloadRootTag,
     const Pass::Option<std::string> &debugTransformRootTag,
-    const Pass::Option<std::string> &transformLibraryFileName,
+    const Pass::ListOption<std::string> &transformLibraryFileNames,
+    const Pass::ListOption<std::string> &transformLibraryDirNames,
     StringRef binaryName) {
   MLIRContext *context = target->getContext();
 
@@ -281,10 +294,10 @@ static void performOptionalDebugActions(
 
   DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_STDERR, {
     llvm::dbgs() << "=== Transform Interpreter Repro ===\n";
-    printReproCall(llvm::dbgs() << "cat <<EOF | ",
-                   root->getName().getStringRef(), passName,
-                   debugPayloadRootTag, debugTransformRootTag,
-                   transformLibraryFileName, binaryName)
+    printReproCall(
+        llvm::dbgs() << "cat <<EOF | ", root->getName().getStringRef(),
+        passName, debugPayloadRootTag, debugTransformRootTag,
+        transformLibraryFileNames, transformLibraryDirNames, binaryName)
         << "\n";
     printModuleForRepro(llvm::dbgs(), root, transform);
     llvm::dbgs() << "\nEOF\n";
@@ -294,7 +307,8 @@ static void performOptionalDebugActions(
   DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_FILE, {
     saveReproToTempFile(llvm::dbgs(), target, transform, passName,
                         debugPayloadRootTag, debugTransformRootTag,
-                        transformLibraryFileName, binaryName);
+                        transformLibraryFileNames, transformLibraryDirNames,
+                        binaryName);
   });
 
   // Remove temporary attributes if they were set.
@@ -573,7 +587,8 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
     const RaggedArray<MappedValue> &extraMappings,
     const TransformOptions &options,
     const Pass::Option<std::string> &transformFileName,
-    const Pass::Option<std::string> &transformLibraryFileName,
+    const Pass::ListOption<std::string> &transformLibraryFileNames,
+    const Pass::ListOption<std::string> &transformLibraryDirNames,
     const Pass::Option<std::string> &debugPayloadRootTag,
     const Pass::Option<std::string> &debugTransformRootTag,
     StringRef binaryName) {
@@ -631,7 +646,9 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
     if (failed(
             mergeSymbolsInto(SymbolTable::getNearestSymbolTable(transformRoot),
                              libraryModule->get()->clone())))
-      return failure();
+      return emitError(
+          transformRoot->getLoc(),
+          "failed to replace declarations with definitions in transform root");
   }
 
   // Step 4
@@ -640,7 +657,8 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
   // repro to stderr and/or a file.
   performOptionalDebugActions(target, transformRoot, passName,
                               debugPayloadRootTag, debugTransformRootTag,
-                              transformLibraryFileName, binaryName);
+                              transformLibraryFileNames,
+                              transformLibraryDirNames, binaryName);
 
   // Step 5
   // ------
@@ -651,50 +669,135 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
 
 LogicalResult transform::detail::interpreterBaseInitializeImpl(
     MLIRContext *context, StringRef transformFileName,
-    StringRef transformLibraryFileName,
-    std::shared_ptr<OwningOpRef<ModuleOp>> &module,
-    std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
+    ArrayRef<std::string> transformLibraryFileNames,
+    ArrayRef<std::string> transformLibraryDirNames,
+    std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
+    std::shared_ptr<OwningOpRef<ModuleOp>> &transformLibraryModule,
     function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
         moduleBuilder) {
-  OwningOpRef<ModuleOp> parsed;
-  if (failed(parseTransformModuleFromFile(context, transformFileName, parsed)))
-    return failure();
-  if (parsed && failed(mlir::verify(*parsed)))
-    return failure();
+  auto unknownLoc = UnknownLoc::get(context);
 
-  OwningOpRef<ModuleOp> parsedLibrary;
-  if (failed(parseTransformModuleFromFile(context, transformLibraryFileName,
-                                          parsedLibrary)))
-    return failure();
-  if (parsedLibrary && failed(mlir::verify(*parsedLibrary)))
-    return failure();
+  // Parse module from file.
+  OwningOpRef<ModuleOp> moduleFromFile;
+  {
+    auto loc = FileLineColLoc::get(context, transformFileName, 0, 0);
+    if (failed(parseTransformModuleFromFile(context, transformFileName,
+                                            moduleFromFile)))
+      return emitError(loc, "failed to parse transform module");
+    if (moduleFromFile && failed(mlir::verify(*moduleFromFile)))
+      return emitError(loc, "failed to verify transform module");
+  }
+
+  // Assemble list of library files: explicitly named files first, in the
+  // given order, followed by the '.mlir' files found directly (i.e., not
+  // recursively) in each library directory.
+  SmallVector<std::string> libraryFileNames;
+  libraryFileNames.append(transformLibraryFileNames.begin(),
+                          transformLibraryFileNames.end());
+
+  for (const std::string &dirName : transformLibraryDirNames) {
+    LLVM_DEBUG(DBGS() << "Opening files in '" << dirName << "':\n");
+
+    auto dirLoc = FileLineColLoc::get(context, dirName, 0, 0);
+    size_t firstFileFromDir = libraryFileNames.size();
+    std::error_code ec;
+    for (llvm::sys::fs::directory_iterator it(dirName, ec), itEnd;
+         it != itEnd && !ec; it.increment(ec)) {
+      const std::string &fileName = it->path();
+
+      // Ask the file system instead of using `it->type()`: the raw entry type
+      // may report symlinks as `symlink_file`, or `type_unknown` on file
+      // systems that do not record it, which would silently skip valid
+      // library files. `is_regular_file` follows symlinks.
+      if (!llvm::sys::fs::is_regular_file(fileName)) {
+        LLVM_DEBUG(DBGS() << "  Skipping non-regular file '" << fileName
+                          << "'\n");
+        continue;
+      }
+
+      if (!StringRef(fileName).endswith(".mlir")) {
+        LLVM_DEBUG(DBGS() << "  Skipping '" << fileName
+                          << "' because it does not end with '.mlir'\n");
+        continue;
+      }
+
+      LLVM_DEBUG(DBGS() << "  Adding '" << fileName << "' to list of files\n");
+      libraryFileNames.push_back(fileName);
+    }
+
+    if (ec)
+      return emitError(dirLoc, Twine("error while opening files in '") +
+                                   dirName + "': " + ec.message());
+
+    // Directory iteration order is unspecified. Sort the files found in this
+    // directory so that symbol merging (including the renaming of colliding
+    // private symbols) and error reporting are deterministic.
+    llvm::sort(libraryFileNames.begin() + firstFileFromDir,
+               libraryFileNames.end());
+  }
 
-  if (parsed) {
-    module = std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsed));
+  // Parse modules from library files.
+  SmallVector<OwningOpRef<ModuleOp>> parsedLibraries;
+  for (const std::string &libraryFileName : libraryFileNames) {
+    OwningOpRef<ModuleOp> parsedLibrary;
+    auto loc = FileLineColLoc::get(context, libraryFileName, 0, 0);
+    if (failed(parseTransformModuleFromFile(context, libraryFileName,
+                                            parsedLibrary)))
+      return emitError(loc, "failed to parse transform library module");
+    if (parsedLibrary && failed(mlir::verify(*parsedLibrary)))
+      return emitError(loc, "failed to verify transform library module");
+    // An empty file name yields a null module, which has nothing to merge.
+    if (parsedLibrary)
+      parsedLibraries.push_back(std::move(parsedLibrary));
+  }
+
+  // Build shared transform module.
+  if (moduleFromFile) {
+    sharedTransformModule =
+        std::make_shared<OwningOpRef<ModuleOp>>(std::move(moduleFromFile));
   } else if (moduleBuilder) {
     // TODO: better location story.
-    auto location = UnknownLoc::get(context);
     auto localModule = std::make_shared<OwningOpRef<ModuleOp>>(
-        ModuleOp::create(location, "__transform"));
+        ModuleOp::create(unknownLoc, "__transform"));
 
     OpBuilder b(context);
     b.setInsertionPointToEnd(localModule->get().getBody());
-    if (std::optional<LogicalResult> result = moduleBuilder(b, location)) {
+    if (std::optional<LogicalResult> result = moduleBuilder(b, unknownLoc)) {
       if (failed(*result))
-        return failure();
-      module = std::move(localModule);
+        return emitError(unknownLoc,
+                         "failed to create shared transform module");
+      sharedTransformModule = std::move(localModule);
     }
   }
 
-  if (!parsedLibrary || !*parsedLibrary)
+  if (parsedLibraries.empty())
     return success();
 
-  if (module && *module) {
-    if (failed(mergeSymbolsInto(module->get(), std::move(parsedLibrary))))
-      return failure();
+  // Merge parsed libraries into one module.
+  // TODO: better location story.
+  OwningOpRef<ModuleOp> mergedParsedLibraries =
+      ModuleOp::create(unknownLoc, "__transform");
+  {
+    mergedParsedLibraries.get()->setAttr("transform.with_named_sequence",
+                                         UnitAttr::get(context));
+    for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) {
+      if (failed(mergeSymbolsInto(mergedParsedLibraries.get(),
+                                  std::move(parsedLibrary))))
+        return emitError(unknownLoc,
+                         "failed to merge transform library modules");
+    }
+  }
+
+  // Use parsed libraries to resolve symbols in shared transform module or
+  // return as separate library module.
+  if (sharedTransformModule && *sharedTransformModule) {
+    if (failed(mergeSymbolsInto(sharedTransformModule->get(),
+                                std::move(mergedParsedLibraries))))
+      return emitError(unknownLoc, "failed to replace declarations with "
+                                   "definitions in shared transform module");
   } else {
-    libraryModule =
-        std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsedLibrary));
+    transformLibraryModule = std::make_shared<OwningOpRef<ModuleOp>>(
+        std::move(mergedParsedLibraries));
   }
   return success();
 }
diff --git a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir
index c219cfe08ac4d6a..c1253cdb92d5002 100644
--- a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir
@@ -2,7 +2,7 @@
 
 // RUN: mlir-opt %s -test-lower-to-llvm -cse | FileCheck %s
 
-// RUN: mlir-opt %s -test-transform-dialect-interpreter="transform-library-file-name=%p/lower-to-llvm-transform-symbol-def.mlir debug-payload-root-tag=payload" \
+// RUN: mlir-opt %s -test-transform-dialect-interpreter="transform-library-file-names=%p/lower-to-llvm-transform-symbol-def.mlir debug-payload-root-tag=payload" \
 // RUN:   -test-transform-dialect-erase-schedule -cse \
 // RUN: | FileCheck %s
 
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
index dd8d141e994da0e..9aa40c52decd83e 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
@@ -1,7 +1,7 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-names=%p/test-interpreter-library/definitions-self-contained.mlir})" \
 // RUN:             --verify-diagnostics
 
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-names=%p/test-interpreter-library/definitions-self-contained.mlir}, test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-names=%p/test-interpreter-library/definitions-self-contained.mlir})" \
 // RUN:             --verify-diagnostics
 
 // The external transform script has a declaration to the named sequence @foo,
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-dir.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-dir.mlir
new file mode 100644
index 000000000000000..c6ab7b68a18e5a0
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-dir.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-dir-names=%p%{fs-sep}test-interpreter-library})" \
+// RUN:             --verify-diagnostics --split-input-file | FileCheck %s
+
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-dir-names=%p%{fs-sep}test-interpreter-library}, test-transform-dialect-interpreter)" \
+// RUN:             --verify-diagnostics --split-input-file | FileCheck %s
+
+// The definition of the @print_message named sequence is provided in another
+// file. It will be included because of the pass option. Repeated application
+// of the same pass, with or without the library option, should not be a
+// problem. Note that the same diagnostic produced twice at the same location
+// only needs to be matched once.
+
+// expected-remark @below {{message}}
+module attributes {transform.with_named_sequence} {
+  // CHECK: transform.named_sequence @print_message
+  transform.named_sequence @print_message(%arg0: !transform.any_op {transform.readonly})
+
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    include @print_message failures(propagate) (%arg0) : (!transform.any_op) -> ()
+  }
+}
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
index 7452deb39b6c18d..5c832cf720b50ce 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def-invalid.mlir}, test-transform-dialect-interpreter)" \
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-names=%p/test-interpreter-external-symbol-def-invalid.mlir}, test-transform-dialect-interpreter)" \
 // RUN:             --verify-diagnostics --split-input-file
 
 // The definition of the @print_message named sequence is provided in another file. It
@@ -8,6 +8,7 @@ module attributes {transform.with_named_sequence} {
   // expected-error @below {{external definition has a mismatching signature}}
   transform.named_sequence private @print_message(!transform.op<"builtin.module"> {transform.readonly})
 
+  // expected-error @below {{failed to replace declarations with definitions in transform root}}
   transform.sequence failures(propagate) {
   ^bb0(%arg0: !transform.op<"builtin.module">):
     include @print_message failures(propagate) (%arg0) : (!transform.op<"builtin.module">) -> ()
@@ -32,6 +33,7 @@ module attributes {transform.with_named_sequence} {
   // expected-error @below {{external definition has mismatching consumption annotations for argument #0}}
   transform.named_sequence private @consuming(%arg0: !transform.any_op {transform.readonly})
 
+  // expected-error @below {{failed to replace declarations with definitions in transform root}}
   transform.sequence failures(suppress) {
   ^bb0(%arg0: !transform.any_op):
     include @consuming failures(suppress) (%arg0) : (!transform.any_op) -> ()
@@ -47,6 +49,7 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 
+  // expected-error @below {{failed to replace declarations with definitions in transform root}}
   transform.sequence failures(suppress) {
   ^bb0(%arg0: !transform.any_op):
     include @print_message failures(propagate) (%arg0) : (!transform.any_op) -> ()
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
index a9083fe3e70788a..cec74d220552236 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
@@ -1,7 +1,7 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-names=%p/test-interpreter-library/definitions-self-contained.mlir})" \
 // RUN:             --verify-diagnostics --split-input-file | FileCheck %s
 
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-names=%p/test-interpreter-library/definitions-self-contained.mlir}, test-transform-dialect-interpreter)" \
 // RUN:             --verify-diagnostics --split-input-file | FileCheck %s
 
 // The definition of the @print_message named sequence is provided in another
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir b/mlir/test/Dialect/Transform/test-interpreter-library/definitions-self-contained.mlir
similarity index 100%
rename from mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir
rename to mlir/test/Dialect/Transform/test-interpreter-library/definitions-self-contained.mlir
diff --git a/mlir/test/Dialect/Transform/test-interpreter-library/definitions-with-unresolved.mlir b/mlir/test/Dialect/Transform/test-interpreter-library/definitions-with-unresolved.mlir
new file mode 100644
index 000000000000000..eeecf7b474cbc6e
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-library/definitions-with-unresolved.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-opt %s
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @print_message(%arg0: !transform.any_op {transform.readonly})
+
+  transform.named_sequence @baz(%arg0: !transform.any_op) {
+    transform.include @print_message failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    transform.yield
+  }
+}
diff --git a/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir b/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir
index 73bc243ad76060d..c7c055c7567e6f3 100644
--- a/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir
+++ b/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-file-name=%p/match_matmul_common.mlir' --verify-diagnostics
+// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-file-names=%p/match_matmul_common.mlir' --verify-diagnostics
 
 module attributes { transform.with_named_sequence } {
   transform.named_sequence @_match_matmul_like(
diff --git a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
index f164a3d1bd99dd0..3e06f2ca0895841 100644
--- a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
+++ b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-file-name=%p/match_matmul_common.mlir' --verify-diagnostics
+// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-file-names=%p/match_matmul_common.mlir' --verify-diagnostics
 
 module attributes { transform.with_named_sequence } {
   transform.named_sequence @_match_matmul_like(
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
index f73deef9d5fd48c..28e5256385efaea 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
@@ -161,7 +161,8 @@ class TestTransformDialectInterpreterPass
     if (failed(transform::detail::interpreterBaseRunOnOperationImpl(
             getOperation(), getArgument(), getSharedTransformModule(),
             getTransformLibraryModule(), extraMapping, options,
-            transformFileName, transformLibraryFileName, debugPayloadRootTag,
+            transformFileName, transformLibraryFileNames,
+            transformLibraryDirNames, debugPayloadRootTag,
             debugTransformRootTag, getBinaryName())))
       return signalPassFailure();
   }
@@ -216,12 +217,16 @@ class TestTransformDialectInterpreterPass
           "the given value as container IR for top-level transform ops. This "
           "allows user control on what transformation to apply. If empty, "
           "select the container of the top-level transform op.")};
-  Option<std::string> transformLibraryFileName{
-      *this, "transform-library-file-name", llvm::cl::init(""),
+  ListOption<std::string> transformLibraryFileNames{
+      *this, "transform-library-file-names", llvm::cl::ZeroOrMore,
+      llvm::cl::desc("Optional filenames containing transform dialect symbol "
+                     "definitions to be injected into the transform module.")};
+  ListOption<std::string> transformLibraryDirNames{
+      *this, "transform-library-dir-names", llvm::cl::ZeroOrMore,
       llvm::cl::desc(
-          "Optional name of the file containing transform dialect symbol "
-          "definitions to be injected into the transform module.")};
-
+          "Optional directories containing transform dialect symbol "
+          "definitions to be injected into the transform module. All '.mlir' "
+          "files directly in each directory (non-recursively) are loaded.")};
   Option<bool> testModuleGeneration{
       *this, "test-module-generation", llvm::cl::init(false),
       llvm::cl::desc("test the generation of the transform module during pass "



More information about the Mlir-commits mailing list