[Mlir-commits] [mlir] 6e44f11 - [mlir] provide a base class for transform interpreter passes

Alex Zinenko llvmlistbot at llvm.org
Fri Feb 3 06:12:41 PST 2023


Author: Alex Zinenko
Date: 2023-02-03T14:12:31Z
New Revision: 6e44f11ed30cc2e9d5453f162403c7e58ee6e5e1

URL: https://github.com/llvm/llvm-project/commit/6e44f11ed30cc2e9d5453f162403c7e58ee6e5e1
DIFF: https://github.com/llvm/llvm-project/commit/6e44f11ed30cc2e9d5453f162403c7e58ee6e5e1.diff

LOG: [mlir] provide a base class for transform interpreter passes

The transform dialect infrastructure does not provide a default
interpreter pass and instead expects users to create their own to ensure
all relevant extensions and dependent dialects are loaded. Provide a
base class for implementing such passes that includes the additional
facilities for debugging and is aware of the multithreaded nature of
pass execution.

Reviewed By: pifon2a, nicolasvasilache

Differential Revision: https://reviews.llvm.org/D142729

Added: 
    mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
    mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
    mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir
    mlir/test/Dialect/Linalg/generalize-tesnor-pack-tile.mlir
    mlir/test/Dialect/Transform/test-interpreter-debug.mlir
    mlir/test/Dialect/Transform/test-interpreter-external-source.mlir
    mlir/test/Dialect/Transform/test-interpreter-external.mlir

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
    mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
    mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
    mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
    mlir/test/Dialect/Linalg/transform-patterns.mlir
    mlir/test/Dialect/Transform/test-interpreter.mlir
    mlir/test/lib/Dialect/Transform/CMakeLists.txt
    mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
    utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
    utils/bazel/llvm-project-overlay/mlir/test/Dialect/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 063b6dec7dc97..bae5d4568a556 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -25,7 +25,9 @@ class TransformOpInterface;
 /// TransformState.
 class TransformOptions {
 public:
-  TransformOptions() {}
+  TransformOptions() = default;
+  TransformOptions(const TransformOptions &) = default;
+  TransformOptions &operator=(const TransformOptions &) = default;
 
   /// Requests computationally expensive checks of the transform and payload IR
   /// well-formedness to be performed before each transformation. In particular,
@@ -200,7 +202,8 @@ class TransformState {
       assert(res.second && "the region scope is already present");
       (void)res;
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
-      assert(state.regionStack.back()->isProperAncestor(&region) &&
+      assert(((state.regionStack.size() == 1 && !state.regionStack.back()) ||
+              state.regionStack.back()->isProperAncestor(&region)) &&
              "scope started at a non-nested region");
       state.regionStack.push_back(&region);
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS

diff  --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
new file mode 100644
index 0000000000000..1b1ad6a74ecd5
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
@@ -0,0 +1,164 @@
+//===- TransformInterpreterPassBase.h ---------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Base class with shared implementation for transform dialect interpreter
+// passes.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERPASSBASE_H
+#define MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERPASSBASE_H
+
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include <memory>
+
+namespace mlir {
+struct LogicalResult;
+class MLIRContext;
+class ModuleOp;
+class Operation;
+template <typename>
+class OwningOpRef;
+class Region;
+
+namespace transform {
+namespace detail {
+/// Non-template part of `TransformInterpreterPassBase::runOnOperation`:
+/// locates the payload root and the transform entry point for `target`
+/// (honoring the debug tag options), optionally dumps a reproducer, and runs
+/// the transform interpreter. `passName` and `binaryName` are only used to
+/// print the reproducer command line.
+LogicalResult interpreterBaseRunOnOperationImpl(
+    Operation *target,
+    StringRef passName,
+    const std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
+    ArrayRef<ArrayRef<MappedValue>> extraMappings,
+    const TransformOptions &options,
+    const Pass::Option<std::string> &transformFileName,
+    const Pass::Option<std::string> &debugPayloadRootTag,
+    const Pass::Option<std::string> &debugTransformRootTag,
+    StringRef binaryName);
+
+/// Non-template part of `TransformInterpreterPassBase::initialize`: parses the
+/// transform script named by `transformFileName` (if non-empty) and stores the
+/// resulting module, which may be null, into `module` so that it can be shared
+/// between copies of the pass.
+LogicalResult interpreterBaseInitializeImpl(
+    MLIRContext *context,
+    StringRef transformFileName,
+    std::shared_ptr<OwningOpRef<ModuleOp>> &module);
+} // namespace detail
+
+/// Base class for transform dialect interpreter passes that can consume and
+/// dump transform dialect scripts in separate files. The pass is controlled by
+/// three string options:
+///
+///   - transformFileName: if non-empty, the name of the file containing the
+///     transform script. If empty, `debugTransformRootTag` is considered or the
+///     pass root operation must contain a single top-level transform op that
+///     will be interpreted.
+///   - debugPayloadRootTag: if non-empty, the value of the attribute named
+///     `kTransformDialectTagAttrName` indicating the single op that is
+///     considered the payload root of the transform interpreter; otherwise, the
+///     root operation of the pass is used.
+///   - debugTransformRootTag: if non-empty, the value of the attribute named
+///     `kTransformDialectTagAttrName` indicating the single top-level transform
+///     op contained in the payload root to be used as the entry point by the
+///     transform interpreter; mutually exclusive with `transformFileName`.
+///
+/// The pass runs the transform dialect interpreter as directed by the options.
+/// It also provides the mechanism to dump reproducers into stderr
+/// (-debug-only=transform-dialect-dump-repro) or into a temporary file
+/// (-debug-only=transform-dialect-save-repro) that can be used with this
+/// pass in a standalone mode.
+///
+/// Concrete passes must derive from this class instead of their generated base
+/// class (or a PassWrapper-based class template taking the concrete pass as
+/// its only template parameter), and supply themselves and that base class
+/// template as template arguments. They are *not* expected to implement
+/// `initialize` or `runOnOperation`. They *are* expected to call the copy
+/// constructor of this class in their copy constructors, short of which the
+/// file-based transform dialect script injection facility will become
+/// nonoperational.
+///
+/// Concrete passes may implement the `runBeforeInterpreter` and
+/// `runAfterInterpreter` to customize the behavior of the pass.
+template <typename Concrete, template <typename> typename GeneratedBase>
+class TransformInterpreterPassBase : public GeneratedBase<Concrete> {
+public:
+  explicit TransformInterpreterPassBase(
+      const TransformOptions &options = TransformOptions())
+      : options(options) {}
+
+  /// Copies the interpreter options and shares (rather than clones) the
+  /// transform module parsed by `initialize`. The generated base is
+  /// copy-constructed explicitly instead of being default-constructed.
+  TransformInterpreterPassBase(const TransformInterpreterPassBase &pass)
+      : GeneratedBase<Concrete>(pass), options(pass.options),
+        sharedTransformModule(pass.sharedTransformModule) {}
+
+  /// Name of the tool printed in reproducer command lines. Concrete passes may
+  /// shadow this static function, since `runOnOperation` calls it through
+  /// `Concrete`.
+  static StringLiteral getBinaryName() { return "mlir-opt"; }
+
+  /// Checks at compile time that the concrete pass declares the three required
+  /// string options, then parses the transform file (if any) into the shared
+  /// transform module.
+  LogicalResult initialize(MLIRContext *context) override {
+
+#define REQUIRE_PASS_OPTION(NAME)                                              \
+  static_assert(                                                               \
+      std::is_same_v<                                                          \
+          std::remove_reference_t<decltype(std::declval<Concrete &>().NAME)>,  \
+          Pass::Option<std::string>>,                                          \
+      "required " #NAME " string pass option is missing")
+
+    REQUIRE_PASS_OPTION(transformFileName);
+    REQUIRE_PASS_OPTION(debugPayloadRootTag);
+    REQUIRE_PASS_OPTION(debugTransformRootTag);
+
+#undef REQUIRE_PASS_OPTION
+
+    StringRef transformFileName =
+        static_cast<Concrete *>(this)->transformFileName;
+    return detail::interpreterBaseInitializeImpl(context, transformFileName,
+                                                 sharedTransformModule);
+  }
+
+  /// Hook for passes to run additional logic in the pass before the
+  /// interpreter. If failure is returned, the pass fails and the interpreter is
+  /// not run.
+  LogicalResult runBeforeInterpreter(Operation *) { return success(); }
+
+  /// Hook for passes to run additional logic in the pass after the interpreter.
+  /// Only runs if everything succeeded before. If failure is returned, the pass
+  /// fails.
+  LogicalResult runAfterInterpreter(Operation *) { return success(); }
+
+  /// Runs `runBeforeInterpreter`, the interpreter and `runAfterInterpreter` in
+  /// order, stopping at and signaling pass failure on the first failure. The
+  /// hooks are dispatched statically through `Concrete`.
+  void runOnOperation() override {
+    auto *pass = static_cast<Concrete *>(this);
+    Operation *op = pass->getOperation();
+    StringRef binaryName = Concrete::getBinaryName();
+    if (failed(pass->runBeforeInterpreter(op)) ||
+        failed(detail::interpreterBaseRunOnOperationImpl(
+            op, pass->getArgument(), sharedTransformModule,
+            /*extraMappings=*/{}, options, pass->transformFileName,
+            pass->debugPayloadRootTag, pass->debugTransformRootTag,
+            binaryName)) ||
+        failed(pass->runAfterInterpreter(op))) {
+      pass->signalPassFailure();
+    }
+  }
+
+protected:
+  /// Transform interpreter options.
+  TransformOptions options;
+
+  /// Returns a read-only reference to shared transform module.
+  const std::shared_ptr<OwningOpRef<ModuleOp>> &
+  getSharedTransformModule() const {
+    return sharedTransformModule;
+  }
+
+private:
+  /// The separate transform module to be used for transformations, shared
+  /// across multiple instances of the pass if it is applied in parallel to
+  /// avoid potentially expensive cloning. MUST NOT be modified after the pass
+  /// has been initialized.
+  std::shared_ptr<OwningOpRef<ModuleOp>> sharedTransformModule = nullptr;
+};
+
+} // namespace transform
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERPASSBASE_H

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 5ecc1f47573c6..5eb0eedb6d7af 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -7,9 +7,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/OwningOpRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/Debug.h"
@@ -828,6 +830,25 @@ transform::applyTransforms(Operation *payloadRoot,
   }
 #endif // NDEBUG
 
+  // If the transform dialect may use PDL which may modify the IR, clone it
+  // before use to avoid concurrent modification in case this is being called
+  // from pass instances running concurrently with a shared transform script.
+  auto *pdlDialect =
+      transform->getContext()->getLoadedDialect<pdl::PDLDialect>();
+  bool hasPDL = transform
+                    .walk([pdlDialect](Operation *op) {
+                      if (op->getDialect() == pdlDialect)
+                        return WalkResult::interrupt();
+                      return WalkResult::advance();
+                    })
+                    .wasInterrupted();
+
+  OwningOpRef<TransformOpInterface> owningCopy;
+  if (hasPDL) {
+    owningCopy = OwningOpRef<TransformOpInterface>(transform->clone());
+    transform = owningCopy.get();
+  }
+
   TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
                        options);
   return state.applyTransform(transform).checkAndReport();

diff  --git a/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
index 68ad95bae3dbc..bf9a255bacad3 100644
--- a/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRTransformDialectTransforms
   CheckUses.cpp
+  TransformInterpreterPassBase.cpp
 
   DEPENDS
   MLIRTransformDialectTransformsIncGen

diff  --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
new file mode 100644
index 0000000000000..d347c9d42d557
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -0,0 +1,349 @@
+//===- TransformInterpreterPassBase.cpp -----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Base class with shared implementation for transform dialect interpreter
+// passes.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/FileUtilities.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/Mutex.h"
+#include "llvm/Support/Path.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+// Debug type used by LLVM_DEBUG and the DBGS() prefix in this file.
+#define DEBUG_TYPE "transform-dialect-interpreter"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+// Debug types requesting a reproducer on stderr or in a temporary file.
+#define DEBUG_TYPE_DUMP_STDERR "transform-dialect-dump-repro"
+#define DEBUG_TYPE_DUMP_FILE "transform-dialect-save-repro"
+
+/// Attribute name through which the transform dialect interpreter is pointed
+/// at specific operations.
+static constexpr llvm::StringLiteral
+    kTransformDialectTagAttrName("transform.target_tag");
+/// Tag value marking the payload root operation.
+static constexpr llvm::StringLiteral
+    kTransformDialectTagPayloadRootValue("payload_root");
+/// Tag value marking the top-level transform operation, i.e. the container of
+/// the transform script to interpret.
+static constexpr llvm::StringLiteral
+    kTransformDialectTagTransformContainerValue("transform_container");
+
+/// Utility to parse the content of a `transformFileName` MLIR file containing
+/// a transform dialect specification into `transformModule`.
+///
+/// An empty `transformFileName` is not an error: the transform script is then
+/// expected to be embedded in the payload IR, so `transformModule` is left
+/// untouched and success is returned. Failing to open or to parse the file is
+/// reported through diagnostics and results in failure.
+static LogicalResult
+parseTransformModuleFromFile(MLIRContext *context,
+                             llvm::StringRef transformFileName,
+                             OwningOpRef<ModuleOp> &transformModule) {
+  if (transformFileName.empty()) {
+    LLVM_DEBUG(
+        DBGS() << "no transform file name specified, assuming the transform "
+                  "module is embedded in the IR next to the top-level\n");
+    return success();
+  }
+  // Parse transformFileName content into a ModuleOp.
+  std::string errorMessage;
+  auto memoryBuffer = mlir::openInputFile(transformFileName, &errorMessage);
+  if (!memoryBuffer) {
+    // Surface the reason reported while opening instead of dropping it.
+    return emitError(FileLineColLoc::get(
+               StringAttr::get(context, transformFileName), 0, 0))
+           << "failed to open transform file: " << errorMessage;
+  }
+  // Tell sourceMgr about this buffer, the parser will pick it up.
+  llvm::SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
+  transformModule =
+      OwningOpRef<ModuleOp>(parseSourceFile<ModuleOp>(sourceMgr, context));
+  // The parser has already emitted diagnostics; a null module means parsing
+  // failed and must not be mistaken for "no transform file was provided".
+  if (!transformModule)
+    return failure();
+  return success();
+}
+
+/// Looks for the unique top-level transform operation in `root` (including
+/// `root` itself). Transform operations nested inside an already-found
+/// transform operation are not visited. On success, returns that operation.
+/// Emits an error and returns nullptr if there is more than one top-level
+/// transform operation, or if there is none; in the latter case the error
+/// suggests using the `filenameOption` option instead.
+static Operation *findTopLevelTransform(Operation *root,
+                                        StringRef filenameOption) {
+  using ::mlir::transform::TransformOpInterface;
+  TransformOpInterface firstFound = nullptr;
+  auto visit = [&](Operation *op) -> WalkResult {
+    auto candidate = dyn_cast<TransformOpInterface>(op);
+    if (!candidate)
+      return WalkResult::advance();
+    if (firstFound) {
+      InFlightDiagnostic diag =
+          candidate.emitError() << "more than one top-level transform op";
+      diag.attachNote(firstFound.getLoc()) << "previous top-level transform op";
+      return WalkResult::interrupt();
+    }
+    firstFound = candidate;
+    // Transform ops nested in a top-level one are not candidates.
+    return WalkResult::skip();
+  };
+  if (root->walk<WalkOrder::PreOrder>(visit).wasInterrupted())
+    return nullptr;
+  if (firstFound)
+    return firstFound;
+  InFlightDiagnostic diag =
+      root->emitError() << "could not find a nested top-level transform op";
+  diag.attachNote() << "use the '" << filenameOption
+                    << "' option to provide transform as external file";
+  return nullptr;
+}
+
+/// Searches `root` and all operations nested in it for the operation carrying
+/// a string attribute named `tagKey` whose value is `tagValue`. Returns that
+/// operation if it is unique. Emits an error on `root` and returns nullptr if
+/// no operation, or more than one operation, carries the tag.
+static Operation *findOpWithTag(Operation *root, StringRef tagKey,
+                                StringRef tagValue) {
+  Operation *match = nullptr;
+  auto visit = [&](Operation *op) -> WalkResult {
+    auto tag = op->getAttrOfType<StringAttr>(tagKey);
+    bool isTagged = tag && tag.getValue() == tagValue;
+    if (!isTagged)
+      return WalkResult::advance();
+    if (!match) {
+      match = op;
+      return WalkResult::advance();
+    }
+    InFlightDiagnostic diag = root->emitError()
+                              << "more than one operation with " << tagKey
+                              << "=\"" << tagValue << "\" attribute";
+    diag.attachNote(match->getLoc()) << "first operation";
+    diag.attachNote(op->getLoc()) << "other operation";
+    return WalkResult::interrupt();
+  };
+  if (root->walk<WalkOrder::PreOrder>(visit).wasInterrupted())
+    return nullptr;
+  if (!match) {
+    root->emitError() << "could not find the operation with " << tagKey
+                      << "=\"" << tagValue << "\" attribute";
+  }
+  return match;
+}
+
+/// Climbs the parent chain starting at `target` and returns the outermost
+/// operation, i.e. the first one (possibly `target` itself) with no parent op.
+static Operation *getRootOperation(Operation *target) {
+  Operation *current = target;
+  for (Operation *parent = current->getParentOp(); parent;
+       parent = parent->getParentOp())
+    current = parent;
+  return current;
+}
+
+/// Prints to `os` the `binaryName` command line that replays a reproducer by
+/// running `passName` nested under `rootOpName`. Empty tag options are
+/// replaced by the default tag values attached when producing the repro.
+// TODO: make binary name optional by querying LLVM command line API for the
+// name of the current binary.
+static llvm::raw_ostream &
+printReproCall(llvm::raw_ostream &os, StringRef rootOpName, StringRef passName,
+               const Pass::Option<std::string> &debugPayloadRootTag,
+               const Pass::Option<std::string> &debugTransformRootTag,
+               StringRef binaryName) {
+  StringRef payloadTag = debugPayloadRootTag;
+  if (debugPayloadRootTag.empty())
+    payloadTag = kTransformDialectTagPayloadRootValue;
+  StringRef transformTag = debugTransformRootTag;
+  if (debugTransformRootTag.empty())
+    transformTag = kTransformDialectTagTransformContainerValue;
+  os << llvm::formatv("{0} --pass-pipeline=\"{1}({2}{{{3}={4} {5}={6}})\"",
+                      binaryName, rootOpName, passName,
+                      debugPayloadRootTag.getArgStr(), payloadTag,
+                      debugTransformRootTag.getArgStr(), transformTag);
+  return os;
+}
+
+/// Prints the operation `root` to `os`, followed by `transform` when the latter
+/// is not `root` or nested in it (e.g., when the transform script comes from a
+/// separate module), so that the output holds both payload and transform IR.
+/// File-local helper, hence `static`.
+static llvm::raw_ostream &printModuleForRepro(llvm::raw_ostream &os,
+                                              Operation *root,
+                                              Operation *transform) {
+  root->print(os);
+  if (!root->isAncestor(transform))
+    transform->print(os);
+  return os;
+}
+
+/// Saves the payload and the transform IR into a temporary file and reports to
+/// `os` the command line replaying the reproducer from that file. Failing to
+/// create or to keep the temporary file is reported to `os` (with the reason)
+/// and does not abort the pass. File-local helper, hence `static`.
+static void
+saveReproToTempFile(llvm::raw_ostream &os, Operation *target,
+                    Operation *transform, StringRef passName,
+                    const Pass::Option<std::string> &debugPayloadRootTag,
+                    const Pass::Option<std::string> &debugTransformRootTag,
+                    StringRef binaryName) {
+  using llvm::sys::fs::TempFile;
+  Operation *root = getRootOperation(target);
+
+  SmallVector<char, 128> tmpPath;
+  llvm::sys::path::system_temp_directory(/*erasedOnReboot=*/true, tmpPath);
+  llvm::sys::path::append(tmpPath, "transform_dialect_%%%%%%.mlir");
+  llvm::Expected<TempFile> tempFile = TempFile::create(tmpPath);
+  if (!tempFile) {
+    // The error held by a failed Expected must be consumed, otherwise the
+    // program aborts when ABI-breaking checks are enabled.
+    os << "could not open temporary file to save the repro: "
+       << llvm::toString(tempFile.takeError()) << "\n";
+    return;
+  }
+
+  {
+    // Scope the stream so it is flushed and destroyed before `keep` closes
+    // the file descriptor it writes to.
+    llvm::raw_fd_ostream fout(tempFile->FD, /*shouldClose=*/false);
+    printModuleForRepro(fout, root, transform);
+    fout.flush();
+  }
+  std::string filename = tempFile->TmpName;
+
+  if (llvm::Error err = tempFile->keep()) {
+    // Same as above: a failing llvm::Error must be consumed.
+    os << "could not preserve the temporary file with the repro: "
+       << llvm::toString(std::move(err)) << "\n";
+    return;
+  }
+
+  os << "=== Transform Interpreter Repro ===\n";
+  printReproCall(os, root->getName().getStringRef(), passName,
+                 debugPayloadRootTag, debugTransformRootTag, binaryName)
+      << " " << filename << "\n";
+  os << "===================================\n";
+}
+
+/// Optionally performs debug actions requested by the user to dump IR and a
+/// repro to stderr (-debug-only=transform-dialect-dump-repro) and/or to a
+/// temporary file (-debug-only=transform-dialect-save-repro). Does nothing if
+/// neither debug type is enabled. In multithreaded mode it only prints a
+/// banner, because producing the repro requires mutating the IR.
+///
+/// The tag attributes used by the repro are attached only for the duration of
+/// the dump. The previous state of the attribute on `target` and `transform`
+/// is restored before returning, so the tags do not leak into the IR seen by
+/// the interpreter or into the pass output.
+static void performOptionalDebugActions(
+    Operation *target, Operation *transform, StringRef passName,
+    const Pass::Option<std::string> &debugPayloadRootTag,
+    const Pass::Option<std::string> &debugTransformRootTag,
+    StringRef binaryName) {
+  MLIRContext *context = target->getContext();
+
+  // If we are not planning to print, bail early.
+  bool hasDebugFlags = false;
+  DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_STDERR, { hasDebugFlags = true; });
+  DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_FILE, { hasDebugFlags = true; });
+  if (!hasDebugFlags)
+    return;
+
+  // We will be mutating the IR to set attributes. If this is running
+  // concurrently on several parts of a container or using a shared transform
+  // script, this would create a race. Bail in multithreaded mode and require
+  // the user to disable threading to dump repros.
+  static llvm::sys::SmartMutex<true> dbgStreamMutex;
+  if (context->isMultithreadingEnabled()) {
+    llvm::sys::SmartScopedLock<true> lock(dbgStreamMutex);
+    llvm::dbgs() << "=======================================================\n";
+    llvm::dbgs() << "|      Transform reproducers cannot be produced       |\n";
+    llvm::dbgs() << "|              in multi-threaded mode!                |\n";
+    llvm::dbgs() << "=======================================================\n";
+    return;
+  }
+
+  Operation *root = getRootOperation(target);
+  // `root` is only used inside DEBUG_WITH_TYPE, which expands to nothing in
+  // NDEBUG builds; avoid an unused-variable warning there.
+  (void)root;
+
+  // Add temporary debug / repro attributes, these must never leak out.
+  // Remember any pre-existing value so it can be restored afterwards.
+  Attribute previousPayloadTag = target->getAttr(kTransformDialectTagAttrName);
+  if (debugPayloadRootTag.empty()) {
+    target->setAttr(
+        kTransformDialectTagAttrName,
+        StringAttr::get(context, kTransformDialectTagPayloadRootValue));
+  }
+  Attribute previousTransformTag =
+      transform->getAttr(kTransformDialectTagAttrName);
+  if (debugTransformRootTag.empty()) {
+    transform->setAttr(
+        kTransformDialectTagAttrName,
+        StringAttr::get(context, kTransformDialectTagTransformContainerValue));
+  }
+
+  DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_STDERR, {
+    llvm::dbgs() << "=== Transform Interpreter Repro ===\n";
+    printReproCall(llvm::dbgs() << "cat <<EOF | ",
+                   root->getName().getStringRef(), passName,
+                   debugPayloadRootTag, debugTransformRootTag, binaryName)
+        << "\n";
+    printModuleForRepro(llvm::dbgs(), root, transform);
+    llvm::dbgs() << "\nEOF\n";
+    llvm::dbgs() << "===================================\n";
+  });
+  DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_FILE, {
+    saveReproToTempFile(llvm::dbgs(), target, transform, passName,
+                        debugPayloadRootTag, debugTransformRootTag, binaryName);
+  });
+
+  // Undo the temporary attributes in reverse order of modification, so the
+  // original state is recovered even if `target` and `transform` coincide.
+  auto restoreTag = [](Operation *op, Attribute previous) {
+    if (previous)
+      op->setAttr(kTransformDialectTagAttrName, previous);
+    else
+      op->removeAttr(kTransformDialectTagAttrName);
+  };
+  if (debugTransformRootTag.empty())
+    restoreTag(transform, previousTransformTag);
+  if (debugPayloadRootTag.empty())
+    restoreTag(target, previousPayloadTag);
+}
+
+/// Shared, template-free implementation of the interpreter pass body: selects
+/// the payload root and the transform entry point according to the options,
+/// optionally dumps a reproducer, and applies the transform.
+LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
+    Operation *target, StringRef passName,
+    const std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
+    ArrayRef<ArrayRef<MappedValue>> extraMappings,
+    const TransformOptions &options,
+    const Pass::Option<std::string> &transformFileName,
+    const Pass::Option<std::string> &debugPayloadRootTag,
+    const Pass::Option<std::string> &debugTransformRootTag,
+    StringRef binaryName) {
+
+  // Step 1
+  // ------
+  // If debugPayloadRootTag was passed, then we are in user-specified selection
+  // of the transformed IR. This corresponds to REPL debug mode. Otherwise, just
+  // apply to `target`.
+  Operation *payloadRoot = target;
+  if (!debugPayloadRootTag.empty()) {
+    payloadRoot = findOpWithTag(target, kTransformDialectTagAttrName,
+                                debugPayloadRootTag);
+    if (!payloadRoot)
+      return failure();
+  }
+
+  // Step 2
+  // ------
+  // If a shared transform was specified separately, use it. Otherwise, the
+  // transform is embedded in the payload IR. If debugTransformRootTag was
+  // passed, then we are in user-specified selection of the transforming IR.
+  // This corresponds to REPL debug mode.
+  bool sharedTransform = (sharedTransformModule && *sharedTransformModule);
+  Operation *transformContainer =
+      sharedTransform ? sharedTransformModule->get() : target;
+  Operation *transformRoot =
+      debugTransformRootTag.empty()
+          ? findTopLevelTransform(transformContainer,
+                                  transformFileName.getArgStr())
+          : findOpWithTag(transformContainer, kTransformDialectTagAttrName,
+                          debugTransformRootTag);
+  if (!transformRoot)
+    return failure();
+
+  // The entry point must implement the transform op interface (so that the
+  // interpreter can apply it) and be allowed at the top level. An op selected
+  // through `debugTransformRootTag` can be any tagged op, so check both
+  // properties instead of relying on an asserting `cast`.
+  auto transformEntry = dyn_cast<TransformOpInterface>(transformRoot);
+  if (!transformEntry ||
+      !transformRoot->hasTrait<PossibleTopLevelTransformOpTrait>()) {
+    return emitError(transformRoot->getLoc())
+           << "expected the transform entry point to be a top-level transform "
+              "op";
+  }
+
+  // Step 3
+  // ------
+  // Optionally perform debug actions requested by the user to dump IR and a
+  // repro to stderr and/or a file.
+  performOptionalDebugActions(target, transformRoot, passName,
+                              debugPayloadRootTag, debugTransformRootTag,
+                              binaryName);
+
+  // Step 4
+  // ------
+  // Apply the transform to the IR
+  return applyTransforms(payloadRoot, transformEntry, extraMappings, options);
+}
+
+/// Parses the transform file, if any, and publishes the result through
+/// `module`. `module` is only reassigned when parsing succeeds.
+LogicalResult transform::detail::interpreterBaseInitializeImpl(
+    MLIRContext *context, StringRef transformFileName,
+    std::shared_ptr<OwningOpRef<ModuleOp>> &module) {
+  OwningOpRef<ModuleOp> transformModule;
+  LogicalResult parsed =
+      parseTransformModuleFromFile(context, transformFileName, transformModule);
+  if (succeeded(parsed)) {
+    // Wrap the (possibly null) module in a shared pointer so that copies of
+    // the pass reuse it instead of re-parsing.
+    module =
+        std::make_shared<OwningOpRef<ModuleOp>>(std::move(transformModule));
+  }
+  return parsed;
+}

diff  --git a/mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir
new file mode 100644
index 0000000000000..02b3babc10cc2
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir
@@ -0,0 +1,108 @@
+// RUN: mlir-opt -split-input-file --test-transform-dialect-interpreter --canonicalize --test-linalg-transform-patterns="test-generalize-tensor-unpack"  %s | FileCheck %s
+
+func.func @KCRSsr_to_KCRS(%arg0: tensor<1x1x4x8x8x32xf32>, %arg1: tensor<1x1x128x64xf32>) -> tensor<1x1x128x64xf32> {
+  %0 = tensor.unpack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x4x8x8x32xf32> -> tensor<1x1x128x64xf32>
+  return %0 : tensor<1x1x128x64xf32>
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+    %1, %loops:4 = transform.structured.tile_to_scf_for %0 [1, 1, 32, 8]
+}
+// CHECK-DAG:   #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 32)>
+// CHECK-DAG:   #[[MAP1:.+]] = affine_map<(d0) -> (d0 floordiv 8)>
+// CHECK:       func.func @KCRSsr_to_KCRS
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %{{.+}} = scf.for %[[R:[a-zA-Z0-9]+]] =
+// CHECK:           %{{.+}} = scf.for %[[S:[a-zA-Z0-9]+]] =
+// CHECK:             %[[IN_R:.+]] = affine.apply #[[MAP0]](%[[R]])
+// CHECK:             %[[IN_S:.+]] = affine.apply #[[MAP1]](%[[S]])
+// CHECK:             %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]]
+// CHECK-SAME:          [0, 0, %[[IN_R]], %[[IN_S]], 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
+// CHECK:             %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
+// CHECK-SAME:          [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1] : tensor<1x1x1x1x8x32xf32> to tensor<8x32xf32>
+// CHECK:             %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
+// CHECK:             %[[TRANSP:.+]] =  linalg.transpose
+// CHECK-SAME:          ins(%[[TILE]]
+// CHECK-SAME:          outs(%[[EMPTY]]
+// CHECK-SAME:          permutation = [1, 0]
+// CHECK:             %{{.+}} = tensor.insert_slice %[[TRANSP]] into %{{.+}}
+
+// -----
+
+func.func @unpack_and_extract_slice(%arg0: tensor<2x8x8x2xf32>, %arg1: tensor<13x15xf32>) -> tensor<13x15xf32> {
+  %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %arg1 : tensor<2x8x8x2xf32> -> tensor<13x15xf32>
+  return %0 : tensor<13x15xf32>
+}
+// CHECK-DAG:   #[[MAP0:.+]] = affine_map<(d0) -> (-d0 + 13, 8)>
+// CHECK-DAG:   #[[MAP1:.+]] = affine_map<(d0) -> (-d0 + 15, 2)>
+// CHECK-DAG:   #[[MAP2:.+]] = affine_map<(d0) -> (d0 floordiv 8)>
+// CHECK-DAG:   #[[MAP3:.+]] = affine_map<(d0) -> (d0 floordiv 2)>
+// CHECK:       func.func @unpack_and_extract_slice
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %{{.+}} = scf.for %[[I:[a-zA-Z0-9]+]] =
+// CHECK:           %[[OUT_I_SZ:.+]] = affine.min #[[MAP0]](%[[I]])
+// CHECK:           %{{.+}} = scf.for %[[J:[a-zA-Z0-9]+]] =
+// CHECK:             %[[OUT_J_SZ:.+]] = affine.min #[[MAP1]](%[[J]])
+// CHECK:             %[[IN_I:.+]] = affine.apply #[[MAP2]](%[[I]])
+// CHECK:             %[[IN_J:.+]] = affine.apply #[[MAP3]](%[[J]])
+// CHECK:             %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]]
+// CHECK-SAME:          [%[[IN_I]], %[[IN_J]], 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
+// CHECK:             %[[ITER_SLICE:.+]] = tensor.extract_slice %{{[a-zA-Z0-9]+}}
+// CHECK-SAME:          [%[[I]], %[[J]]] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]]
+// CHECK:             %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
+// CHECK-SAME:          [0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1] : tensor<1x1x8x2xf32> to tensor<8x2xf32>
+// CHECK:             %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xf32>
+// CHECK:             %[[TRANSP:.+]] =  linalg.transpose
+// CHECK-SAME:          ins(%[[TILE]] : tensor<8x2xf32>)
+// CHECK-SAME:          outs(%[[EMPTY]] : tensor<8x2xf32>)
+// CHECK-SAME:          permutation = [0, 1]
+// CHECK:             %[[UNPACK_TILE:.+]] = tensor.extract_slice %[[TRANSP]]
+// CHECK-SAME:          [0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1]
+// CHECK:             %[[INSERT1:.+]] = tensor.insert_slice %[[UNPACK_TILE]] into %[[ITER_SLICE]]
+// CHECK-SAME:          [0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1]
+// CHECK:             %[[INSERT2:.+]] = tensor.insert_slice %[[INSERT1]] into %{{[a-zA-Z0-9]+}}
+// CHECK-SAME:          [%[[I]], %[[J]]] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1]
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+    %1, %loops:2 = transform.structured.tile_to_scf_for %0 [8, 2]
+}
+
+// -----
+
+func.func @CKkc_to_KC(%arg0: tensor<32x4x32x8xf32>, %arg1: tensor<128x256xf32>) -> tensor<128x256xf32> {
+  %0 = tensor.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : tensor<32x4x32x8xf32> -> tensor<128x256xf32>
+  return %0 : tensor<128x256xf32>
+}
+// CHECK-DAG:   #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 32)>
+// CHECK-DAG:   #[[MAP1:.+]] = affine_map<(d0) -> (d0 floordiv 8)>
+// CHECK:       func.func @CKkc_to_KC
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %{{.+}} = scf.for %[[K:[a-zA-Z0-9]+]] =
+// CHECK:           %{{.+}} = scf.for %[[C:[a-zA-Z0-9]+]] =
+// CHECK:             %[[IN_K:.+]] = affine.apply #[[MAP0]](%[[K]])
+// CHECK:             %[[IN_C:.+]] = affine.apply #[[MAP1]](%[[C]])
+// CHECK:             %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]]
+// CHECK-SAME:          [%[[IN_C]], %[[IN_K]], 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
+// CHECK:             %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
+// CHECK-SAME:          [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<1x1x32x8xf32> to tensor<32x8xf32>
+// CHECK:             %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
+// CHECK:             %[[TRANSP:.+]] =  linalg.transpose
+// CHECK-SAME:          ins(%[[TILE]]
+// CHECK-SAME:          outs(%[[EMPTY]]
+// CHECK-SAME:          permutation = [0, 1]
+// CHECK:             %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %{{[a-zA-Z0-9]+}}
+// CHECK-SAME:          [%[[K]], %[[C]]] [32, 8] [1, 1]
+
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+    %1, %loops:2 = transform.structured.tile_to_scf_for %0 [32, 8]
+}

diff  --git a/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
index a03161f9648fe..cc734c24d4f56 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
@@ -55,114 +55,3 @@ func.func @simple_CNnc_to_NC(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<32x8xf32
 //                They have the same type, so the insert_slice op is folded
 //                away.
 // CHECK:         return %[[TRANSP]]
-
-// -----
-
-// RUN: mlir-opt -split-input-file --test-transform-dialect-interpreter --canonicalize --test-linalg-transform-patterns="test-generalize-tensor-unpack"  %s | FileCheck %s --check-prefix=CHECK-TRANS
-
-func.func @KCRSsr_to_KCRS(%arg0: tensor<1x1x4x8x8x32xf32>, %arg1: tensor<1x1x128x64xf32>) -> tensor<1x1x128x64xf32> {
-  %0 = tensor.unpack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x4x8x8x32xf32> -> tensor<1x1x128x64xf32>
-  return %0 : tensor<1x1x128x64xf32>
-}
-
-transform.sequence failures(propagate) {
-  ^bb0(%arg1: !pdl.operation):
-    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1 : (!pdl.operation) -> !pdl.operation
-    %1, %loops:4 = transform.structured.tile_to_scf_for %0 [1, 1, 32, 8]
-}
-// CHECK-TRANS-DAG:   #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 32)>
-// CHECK-TRANS-DAG:   #[[MAP1:.+]] = affine_map<(d0) -> (d0 floordiv 8)>
-// CHECK-TRANS:       func.func @KCRSsr_to_KCRS
-// CHECK-TRANS-SAME:    %[[SRC:[a-zA-Z0-9]+]]
-// CHECK-TRANS-SAME:    %[[DEST:[a-zA-Z0-9]+]]
-// CHECK-TRANS:         %{{.+}} = scf.for %[[R:[a-zA-Z0-9]+]] =
-// CHECK-TRANS:           %{{.+}} = scf.for %[[S:[a-zA-Z0-9]+]] =
-// CHECK-TRANS:             %[[IN_R:.+]] = affine.apply #[[MAP0]](%[[R]])
-// CHECK-TRANS:             %[[IN_S:.+]] = affine.apply #[[MAP1]](%[[S]])
-// CHECK-TRANS:             %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]]
-// CHECK-TRANS-SAME:          [0, 0, %[[IN_R]], %[[IN_S]], 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
-// CHECK-TRANS:             %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
-// CHECK-TRANS-SAME:          [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1] : tensor<1x1x1x1x8x32xf32> to tensor<8x32xf32>
-// CHECK-TRANS:             %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
-// CHECK-TRANS:             %[[TRANSP:.+]] =  linalg.transpose
-// CHECK-TRANS-SAME:          ins(%[[TILE]]
-// CHECK-TRANS-SAME:          outs(%[[EMPTY]]
-// CHECK-TRANS-SAME:          permutation = [1, 0]
-// CHECK-TRANS:             %{{.+}} = tensor.insert_slice %[[TRANSP]] into %{{.+}}
-
-// -----
-
-func.func @unpack_and_extract_slice(%arg0: tensor<2x8x8x2xf32>, %arg1: tensor<13x15xf32>) -> tensor<13x15xf32> {
-  %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %arg1 : tensor<2x8x8x2xf32> -> tensor<13x15xf32>
-  return %0 : tensor<13x15xf32>
-}
-// CHECK-TRANS-DAG:   #[[MAP0:.+]] = affine_map<(d0) -> (-d0 + 13, 8)>
-// CHECK-TRANS-DAG:   #[[MAP1:.+]] = affine_map<(d0) -> (-d0 + 15, 2)>
-// CHECK-TRANS-DAG:   #[[MAP2:.+]] = affine_map<(d0) -> (d0 floordiv 8)>
-// CHECK-TRANS-DAG:   #[[MAP3:.+]] = affine_map<(d0) -> (d0 floordiv 2)>
-// CHECK-TRANS:       func.func @unpack_and_extract_slice
-// CHECK-TRANS-SAME:    %[[SRC:[a-zA-Z0-9]+]]
-// CHECK-TRANS-SAME:    %[[DEST:[a-zA-Z0-9]+]]
-// CHECK-TRANS:         %{{.+}} = scf.for %[[I:[a-zA-Z0-9]+]] =
-// CHECK-TRANS:           %[[OUT_I_SZ:.+]] = affine.min #[[MAP0]](%[[I]])
-// CHECK-TRANS:           %{{.+}} = scf.for %[[J:[a-zA-Z0-9]+]] =
-// CHECK-TRANS:             %[[OUT_J_SZ:.+]] = affine.min #[[MAP1]](%[[J]])
-// CHECK-TRANS:             %[[IN_I:.+]] = affine.apply #[[MAP2]](%[[I]])
-// CHECK-TRANS:             %[[IN_J:.+]] = affine.apply #[[MAP3]](%[[J]])
-// CHECK-TRANS:             %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]]
-// CHECK-TRANS-SAME:          [%[[IN_I]], %[[IN_J]], 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
-// CHECK-TRANS:             %[[ITER_SLICE:.+]] = tensor.extract_slice %{{[a-zA-Z0-9]+}}
-// CHECK-TRANS-SAME:          [%[[I]], %[[J]]] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]]
-// CHECK-TRANS:             %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
-// CHECK-TRANS-SAME:          [0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1] : tensor<1x1x8x2xf32> to tensor<8x2xf32>
-// CHECK-TRANS:             %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xf32>
-// CHECK-TRANS:             %[[TRANSP:.+]] =  linalg.transpose
-// CHECK-TRANS-SAME:          ins(%[[TILE]] : tensor<8x2xf32>)
-// CHECK-TRANS-SAME:          outs(%[[EMPTY]] : tensor<8x2xf32>)
-// CHECK-TRANS-SAME:          permutation = [0, 1]
-// CHECK-TRANS:             %[[UNPACK_TILE:.+]] = tensor.extract_slice %[[TRANSP]]
-// CHECK-TRANS-SAME:          [0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1]
-// CHECK-TRANS:             %[[INSERT1:.+]] = tensor.insert_slice %[[UNPACK_TILE]] into %[[ITER_SLICE]]
-// CHECK-TRANS-SAME:          [0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1]
-// CHECK-TRANS:             %[[INSERT2:.+]] = tensor.insert_slice %[[INSERT1]] into %{{[a-zA-Z0-9]+}}
-// CHECK-TRANS-SAME:          [%[[I]], %[[J]]] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1]
-
-transform.sequence failures(propagate) {
-  ^bb0(%arg1: !pdl.operation):
-    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1 : (!pdl.operation) -> !pdl.operation
-    %1, %loops:2 = transform.structured.tile_to_scf_for %0 [8, 2]
-}
-
-// -----
-
-func.func @CKkc_to_KC(%arg0: tensor<32x4x32x8xf32>, %arg1: tensor<128x256xf32>) -> tensor<128x256xf32> {
-  %0 = tensor.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : tensor<32x4x32x8xf32> -> tensor<128x256xf32>
-  return %0 : tensor<128x256xf32>
-}
-// CHECK-TRANS-DAG:   #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 32)>
-// CHECK-TRANS-DAG:   #[[MAP1:.+]] = affine_map<(d0) -> (d0 floordiv 8)>
-// CHECK-TRANS:       func.func @CKkc_to_KC
-// CHECK-TRANS-SAME:    %[[SRC:[a-zA-Z0-9]+]]
-// CHECK-TRANS-SAME:    %[[DEST:[a-zA-Z0-9]+]]
-// CHECK-TRANS:         %{{.+}} = scf.for %[[K:[a-zA-Z0-9]+]] =
-// CHECK-TRANS:           %{{.+}} = scf.for %[[C:[a-zA-Z0-9]+]] =
-// CHECK-TRANS:             %[[IN_K:.+]] = affine.apply #[[MAP0]](%[[K]])
-// CHECK-TRANS:             %[[IN_C:.+]] = affine.apply #[[MAP1]](%[[C]])
-// CHECK-TRANS:             %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]]
-// CHECK-TRANS-SAME:          [%[[IN_C]], %[[IN_K]], 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
-// CHECK-TRANS:             %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
-// CHECK-TRANS-SAME:          [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<1x1x32x8xf32> to tensor<32x8xf32>
-// CHECK-TRANS:             %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
-// CHECK-TRANS:             %[[TRANSP:.+]] =  linalg.transpose
-// CHECK-TRANS-SAME:          ins(%[[TILE]]
-// CHECK-TRANS-SAME:          outs(%[[EMPTY]]
-// CHECK-TRANS-SAME:          permutation = [0, 1]
-// CHECK-TRANS:             %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %{{[a-zA-Z0-9]+}}
-// CHECK-TRANS-SAME:          [%[[K]], %[[C]]] [32, 8] [1, 1]
-
-
-transform.sequence failures(propagate) {
-  ^bb0(%arg1: !pdl.operation):
-    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1 : (!pdl.operation) -> !pdl.operation
-    %1, %loops:2 = transform.structured.tile_to_scf_for %0 [32, 8]
-}

diff  --git a/mlir/test/Dialect/Linalg/generalize-tesnor-pack-tile.mlir b/mlir/test/Dialect/Linalg/generalize-tesnor-pack-tile.mlir
new file mode 100644
index 0000000000000..09ebc45ccb57b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/generalize-tesnor-pack-tile.mlir
@@ -0,0 +1,103 @@
+// RUN: mlir-opt -split-input-file --test-transform-dialect-interpreter --canonicalize --test-linalg-transform-patterns="test-generalize-tensor-pack"  %s | FileCheck %s
+
+func.func @KCRS_to_KCRSsr(%arg0: tensor<1x1x128x64xf32>, %arg1: tensor<1x1x4x8x8x32xf32>) -> tensor<1x1x4x8x8x32xf32> {
+  %0 = tensor.pack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x128x64xf32> -> tensor<1x1x4x8x8x32xf32>
+  return %0 : tensor<1x1x4x8x8x32xf32>
+}
+// CHECK-DAG:   #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 32)>
+// CHECK-DAG:   #[[MAP1:.+]] = affine_map<(d0) -> (d0 * -32 + 128, 32)>
+// CHECK-DAG:   #[[MAP2:.+]] = affine_map<(d0) -> (d0 * 8)>
+// CHECK-DAG:   #[[MAP3:.+]] = affine_map<(d0) -> (d0 * -8 + 64, 8)>
+// CHECK:       func.func @KCRS_to_KCRSsr
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %{{.+}} = scf.for %[[R:[a-zA-Z0-9]+]] =
+// CHECK:           %{{.+}} = scf.for %[[S:[a-zA-Z0-9]+]] =
+// CHECK:             %[[IN_R:.+]] = affine.apply #[[MAP0]](%[[R]])
+// CHECK:             %[[IN_R_SZ:.+]] = affine.min #[[MAP1]](%[[R]])
+// CHECK:             %[[IN_S:.+]] = affine.apply #[[MAP2]](%[[S]])
+// CHECK:             %[[IN_S_SZ:.+]] = affine.min #[[MAP3]](%[[S]])
+// CHECK:             %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]]
+// CHECK-SAME:          [0, 0, %[[IN_R]], %[[IN_S]]] [1, 1, %[[IN_R_SZ]], %[[IN_S_SZ]]] [1, 1, 1, 1]
+// CHECK:             %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
+// CHECK-SAME:          [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<1x1x?x?xf32> to tensor<32x8xf32>
+// CHECK:             %[[EMPTY:.+]] = tensor.empty() : tensor<8x32xf32>
+// CHECK:             %[[TRANSP:.+]] =  linalg.transpose
+// CHECK-SAME:          ins(%[[TILE]]
+// CHECK-SAME:          outs(%[[EMPTY]]
+// CHECK-SAME:          permutation = [1, 0]
+// CHECK:             %{{.+}} = tensor.insert_slice %[[TRANSP]] into %{{.+}}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+    %1, %loops:4 = transform.structured.tile_to_scf_for %0 [1, 1, 1, 1]
+}
+
+// -----
+
+func.func @pad_and_pack(%arg0: tensor<13x15xf32>, %arg1: tensor<2x8x8x2xf32>, %arg2: f32) -> tensor<2x8x8x2xf32> {
+  %0 = tensor.pack %arg0 padding_value(%arg2 : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %arg1 : tensor<13x15xf32> -> tensor<2x8x8x2xf32>
+  return %0 : tensor<2x8x8x2xf32>
+}
+// CHECK:       func.func @pad_and_pack
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[PAD_VAL:[a-zA-Z0-9]+]]
+// CHECK:         scf.for
+// CHECK:           scf.for
+// CHECK:             %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]]
+// CHECK:             %[[PAD:.+]] = tensor.pad %[[SRC_SLICE]]
+// CHECK:               tensor.yield %[[PAD_VAL]]
+// CHECK:             } : tensor<?x?xf32> to tensor<8x2xf32>
+// CHECK:             %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xf32>
+// CHECK:             %[[TRANSP:.+]] = linalg.transpose
+// CHECK-SAME:          ins(%[[PAD]] : tensor<8x2xf32>)
+// CHECK-SAME:          outs(%[[EMPTY]] : tensor<8x2xf32>)
+// CHECK-SAME:          permutation = [0, 1]
+// CHECK:             %{{.+}} = tensor.insert_slice %[[TRANSP]] into %{{.+}}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+    %1, %loops:2 = transform.structured.tile_to_scf_for %0 [1, 1]
+}
+
+// -----
+
+
+func.func @KC_to_CKkc(%arg0: tensor<128x256xf32>, %arg1: tensor<32x4x32x8xf32>) -> tensor<32x4x32x8xf32> {
+  %0 = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : tensor<128x256xf32> -> tensor<32x4x32x8xf32>
+  return %0 : tensor<32x4x32x8xf32>
+}
+// CHECK-DAG:   #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 32)>
+// CHECK-DAG:   #[[MAP1:.+]] = affine_map<(d0) -> (d0 * -32 + 128, 32)>
+// CHECK-DAG:   #[[MAP2:.+]] = affine_map<(d0) -> (d0 * 8)>
+// CHECK-DAG:   #[[MAP3:.+]] = affine_map<(d0) -> (d0 * -8 + 256, 8)>
+// CHECK:       func.func @KC_to_CKkc
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %{{.+}} = scf.for %[[C:[a-zA-Z0-9]+]] =
+// CHECK:           %{{.+}} = scf.for %[[K:[a-zA-Z0-9]+]] =
+// CHECK-DAG:         %[[IN_K:.+]] = affine.apply #[[MAP0]](%[[K]])
+// CHECK-DAG:         %[[IN_K_SZ:.+]] = affine.min #[[MAP1]](%[[K]])
+// CHECK-DAG:         %[[IN_C:.+]] = affine.apply #[[MAP2]](%[[C]])
+// CHECK-DAG:         %[[IN_C_SZ:.+]] = affine.min #[[MAP3]](%[[C]])
+// CHECK:             %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]]
+// CHECK-SAME:          [%[[IN_K]], %[[IN_C]]] [%[[IN_K_SZ]], %[[IN_C_SZ]]] [1, 1]
+// CHECK:             %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
+// CHECK-SAME:          [0, 0] [32, 8] [1, 1] : tensor<?x?xf32> to tensor<32x8xf32>
+// CHECK:             %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
+// CHECK:             %[[TRANSP:.+]] =  linalg.transpose
+// CHECK-SAME:          ins(%[[TILE]]
+// CHECK-SAME:          outs(%[[EMPTY]]
+// CHECK-SAME:          permutation = [0, 1]
+// CHECK:             %[[SUB_ITER:.+]] = tensor.insert_slice %[[TRANSP]] into %{{[a-zA-Z0-9]+}}
+// CHECK-SAME:          [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<32x8xf32> into tensor<1x1x32x8xf32>
+// CHECK:             %{{.+}} = tensor.insert_slice %[[SUB_ITER]] into %{{[a-zA-Z0-9]+}}
+// CHECK-SAME:          [%[[C]], %[[K]], 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<1x1x32x8xf32> into tensor<32x4x32x8xf32>
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+    %1, %loops:2 = transform.structured.tile_to_scf_for %0 [1, 1]
+}

diff  --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index e2063e181480b..6214e4e527f14 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -1,7 +1,5 @@
 // RUN: mlir-opt %s -test-transform-dialect-interpreter -test-linalg-transform-patterns=test-patterns -split-input-file | FileCheck %s
 
-// -----
-
 func.func @dot(%x: memref<?xf32, strided<[1], offset: ?>>,
           %y: memref<?xf32, strided<[1], offset: ?>>,
           %v: memref<f32>) {

diff  --git a/mlir/test/Dialect/Transform/test-interpreter-debug.mlir b/mlir/test/Dialect/Transform/test-interpreter-debug.mlir
new file mode 100644
index 0000000000000..efb9b8429b19c
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-debug.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{debug-payload-root-tag=payload debug-transform-root-tag=transform})" \
+// RUN:             --allow-unregistered-dialect --split-input-file --verify-diagnostics
+
+// expected-error @below {{could not find the operation with transform.target_tag="payload" attribute}}
+module {
+  transform.sequence failures(suppress) {
+  ^bb0(%arg0: !transform.any_op):
+  }
+}
+
+// -----
+
+// expected-error @below {{could not find the operation with transform.target_tag="transform" attribute}}
+module {
+  transform.sequence failures(suppress) {
+  ^bb0(%arg0: !transform.any_op):
+  }
+
+  module attributes {transform.target_tag="payload"} {}
+}
+
+// -----
+
+// expected-error @below {{more than one operation with transform.target_tag="transform" attribute}}
+module {
+  // expected-note @below {{first operation}}
+  transform.sequence failures(propagate) attributes {transform.target_tag="transform"} {
+  ^bb0(%arg0: !transform.any_op):
+  }
+
+  // expected-note @below {{other operation}}
+  transform.sequence failures(propagate) attributes {transform.target_tag="transform"} {
+  ^bb0(%arg0: !transform.any_op):
+  }
+
+  module attributes {transform.target_tag="payload"} {}
+}
+
+// -----
+
+module {
+  // expected-error @below {{expected the transform entry point to be a top-level transform op}}
+  func.func private @foo() attributes {transform.target_tag="transform"}
+
+  module attributes {transform.target_tag="payload"} {}
+}
+
+// -----
+
+module {
+  transform.sequence failures(suppress) attributes {transform.target_tag="transform"} {
+  ^bb0(%arg0: !transform.any_op):
+    transform.test_print_remark_at_operand %arg0, "payload" : !transform.any_op
+  }
+
+  // This will not be executed because it's not tagged.
+  transform.sequence failures(suppress)  {
+  ^bb0(%arg0: !transform.any_op):
+    transform.test_print_remark_at_operand %arg0, "some other text that is not printed" : !transform.any_op
+  }
+
+  module {
+    module {}
+    // expected-remark @below {{payload}}
+    module attributes {transform.target_tag="payload"} {}
+    module {}
+  }
+}

diff  --git a/mlir/test/Dialect/Transform/test-interpreter-external-source.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-source.mlir
new file mode 100644
index 0000000000000..c3e5bff898de5
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-source.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt %s
+// Only parsing needs to succeed here; this file is used as input data by another test.
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  transform.test_print_remark_at_operand %arg0, "outer" : !transform.any_op
+  transform.sequence %arg0 : !transform.any_op failures(propagate) attributes {transform.target_tag="transform"} {
+  ^bb1(%arg1: !transform.any_op):
+    transform.test_print_remark_at_operand %arg1, "inner" : !transform.any_op
+  }
+}

diff  --git a/mlir/test/Dialect/Transform/test-interpreter-external.mlir b/mlir/test/Dialect/Transform/test-interpreter-external.mlir
new file mode 100644
index 0000000000000..5ac6b66c817af
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-external.mlir
@@ -0,0 +1,8 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-source.mlir})" \
+// RUN:             --verify-diagnostics
+
+// The schedule in the separate file emits remarks at the payload root.
+
+// expected-remark @below {{outer}}
+// expected-remark @below {{inner}}
+module {}

diff  --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index b0b16abc0399c..f470606c78cf5 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1091,3 +1091,22 @@ transform.sequence failures(propagate) {
   // expected-error @below {{attempting to assign a null parameter to this transform value}}
   %0 = transform.test_produce_null_param : !transform.param<i64>
 }
+
+// -----
+
+// expected-error @below {{could not find a nested top-level transform op}}
+// expected-note @below {{use the 'transform-file-name' option to provide transform as external file}}
+module {
+}
+
+// -----
+
+// expected-note @below {{previous top-level transform op}}
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+}
+
+// expected-error @below {{more than one top-level transform op}}
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+}

diff  --git a/mlir/test/lib/Dialect/Transform/CMakeLists.txt b/mlir/test/lib/Dialect/Transform/CMakeLists.txt
index 0119fb07d1c68..56c7031481b95 100644
--- a/mlir/test/lib/Dialect/Transform/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Transform/CMakeLists.txt
@@ -19,4 +19,5 @@ add_mlir_library(MLIRTestTransformDialect
   MLIRPass
   MLIRPDLDialect
   MLIRTransformDialect
+  MLIRTransformDialectTransforms
 )

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
index 7d049eb98be51..afc7ed739fdc0 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Pass/Pass.h"
@@ -21,16 +22,22 @@ using namespace mlir;
 namespace {
+/// Module-anchored PassWrapper usable as a TransformInterpreterPassBase base.
+template <typename Derived>
+class ModulePassWrapper : public PassWrapper<Derived, OperationPass<ModuleOp>> {
+};
+
 /// Simple pass that applies transform dialect ops directly contained in a
 /// module.
 class TestTransformDialectInterpreterPass
-    : public PassWrapper<TestTransformDialectInterpreterPass,
-                         OperationPass<ModuleOp>> {
+    : public transform::TransformInterpreterPassBase<
+          TestTransformDialectInterpreterPass, ModulePassWrapper> {
 public:
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
       TestTransformDialectInterpreterPass)
 
   TestTransformDialectInterpreterPass() = default;
   TestTransformDialectInterpreterPass(
-      const TestTransformDialectInterpreterPass &) {}
+      const TestTransformDialectInterpreterPass &pass)
+      : TransformInterpreterPassBase(pass) {}
 
   StringRef getArgument() const override {
     return "test-transform-dialect-interpreter";
@@ -101,15 +108,12 @@ class TestTransformDialectInterpreterPass
           getContext(), bindSecondExtraToParams, extraMappingStorage));
     }
 
-    ModuleOp module = getOperation();
-    for (auto op :
-         module.getBody()->getOps<transform::TransformOpInterface>()) {
-      if (failed(transform::applyTransforms(
-              module, op, extraMapping,
-              transform::TransformOptions().enableExpensiveChecks(
-                  enableExpensiveChecks))))
-        return signalPassFailure();
-    }
+    options = options.enableExpensiveChecks(enableExpensiveChecks);
+    if (failed(transform::detail::interpreterBaseRunOnOperationImpl(
+            getOperation(), getArgument(), getSharedTransformModule(),
+            extraMapping, options, transformFileName, debugPayloadRootTag,
+            debugTransformRootTag, getBinaryName())))
+      return signalPassFailure();
   }
 
   Option<bool> enableExpensiveChecks{
@@ -134,6 +138,25 @@ class TestTransformDialectInterpreterPass
       *this, "bind-second-extra-to-params",
       llvm::cl::desc("bind the second extra argument of the top-level op to "
                      "the given integer parameters")};
+  Option<std::string> transformFileName{
+      *this, "transform-file-name", llvm::cl::init(""),
+      llvm::cl::desc(
+          "Optional filename containing a transform dialect specification to "
+          "apply. If left empty, the IR is assumed to contain one top-level "
+          "transform dialect operation somewhere in the module.")};
+  Option<std::string> debugPayloadRootTag{
+      *this, "debug-payload-root-tag", llvm::cl::init(""),
+      llvm::cl::desc(
+          "Select the operation with 'transform.target_tag' attribute having "
+          "the given value as payload IR root. If empty select the pass anchor "
+          "operation as the payload IR root.")};
+  Option<std::string> debugTransformRootTag{
+      *this, "debug-transform-root-tag", llvm::cl::init(""),
+      llvm::cl::desc(
+          "Select the operation with 'transform.target_tag' attribute having "
+          "the given value as container IR for top-level transform ops. This "
+          "allows user control on what transformation to apply. If empty, "
+          "select the container of the top-level transform op.")};
 };
 
 struct TestTransformDialectEraseSchedulePass

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 1534452a10eea..fc9a057111b04 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -9204,8 +9204,10 @@ cc_library(
     deps = [
         ":Analysis",
         ":IR",
+        ":Parser",
         ":Pass",
         ":SideEffectInterfaces",
+        ":Support",
         ":TransformDialect",
         ":TransformDialectTransformsIncGen",
         "//llvm:Support",

diff  --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index a3e96b5c633b7..01524c08c8ccf 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -319,6 +319,7 @@ cc_library(
         "//mlir:PDLDialect",
         "//mlir:Pass",
         "//mlir:TransformDialect",
+        "//mlir:TransformDialectTransforms",
     ],
 )
 

diff  --git a/utils/bazel/llvm-project-overlay/mlir/test/Dialect/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/Dialect/BUILD.bazel
index 6e997e9f47e77..adfb3d0360203 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/Dialect/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/Dialect/BUILD.bazel
@@ -13,7 +13,12 @@ package(default_visibility = ["//visibility:public"])
             "//mlir:mlir-opt",
             "//mlir:mlir-translate",
             "//mlir/test:lit_data",
-        ],
+        ] + glob([
+            "Transform/*-source.mlir",
+        ])
+    )
+    for src in glob(
+        include=["**/*.mlir"],
+        exclude=["Transform/*-source.mlir"]
     )
-    for src in glob(["**/*.mlir"])
 ]


        


More information about the Mlir-commits mailing list