[Mlir-commits] [mlir] 6e44f11 - [mlir] provide a base class for transform interpreter passes
Alex Zinenko
llvmlistbot at llvm.org
Fri Feb 3 06:12:41 PST 2023
Author: Alex Zinenko
Date: 2023-02-03T14:12:31Z
New Revision: 6e44f11ed30cc2e9d5453f162403c7e58ee6e5e1
URL: https://github.com/llvm/llvm-project/commit/6e44f11ed30cc2e9d5453f162403c7e58ee6e5e1
DIFF: https://github.com/llvm/llvm-project/commit/6e44f11ed30cc2e9d5453f162403c7e58ee6e5e1.diff
LOG: [mlir] provide a base class for transform interpreter passes
The transform dialect infrastructure does not provide a default
interpreter pass and instead expects users to create their own to ensure
all relevant extensions and dependent dialects are loaded. Provide a
base class for implementing such passes that includes the additional
facilities for debugging and is aware of the multithreaded nature of
pass execution.
Reviewed By: pifon2a, nicolasvasilache
Differential Revision: https://reviews.llvm.org/D142729
Added:
mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir
mlir/test/Dialect/Linalg/generalize-tesnor-pack-tile.mlir
mlir/test/Dialect/Transform/test-interpreter-debug.mlir
mlir/test/Dialect/Transform/test-interpreter-external-source.mlir
mlir/test/Dialect/Transform/test-interpreter-external.mlir
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
mlir/test/Dialect/Linalg/transform-patterns.mlir
mlir/test/Dialect/Transform/test-interpreter.mlir
mlir/test/lib/Dialect/Transform/CMakeLists.txt
mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
utils/bazel/llvm-project-overlay/mlir/test/Dialect/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 063b6dec7dc97..bae5d4568a556 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -25,7 +25,9 @@ class TransformOpInterface;
/// TransformState.
class TransformOptions {
public:
- TransformOptions() {}
+ TransformOptions() = default;
+ TransformOptions(const TransformOptions &) = default;
+ TransformOptions &operator=(const TransformOptions &) = default;
/// Requests computationally expensive checks of the transform and payload IR
/// well-formedness to be performed before each transformation. In particular,
@@ -200,7 +202,8 @@ class TransformState {
assert(res.second && "the region scope is already present");
(void)res;
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
- assert(state.regionStack.back()->isProperAncestor(&region) &&
+ assert(((state.regionStack.size() == 1 && !state.regionStack.back()) ||
+ state.regionStack.back()->isProperAncestor(&region)) &&
"scope started at a non-nested region");
state.regionStack.push_back(&region);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
new file mode 100644
index 0000000000000..1b1ad6a74ecd5
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
@@ -0,0 +1,164 @@
+//===- TransformInterpreterPassBase.h ---------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Base class with shared implementation for transform dialect interpreter
+// passes.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERPASSBASE_H
+#define MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERPASSBASE_H
+
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include <memory>
+
+namespace mlir {
+struct LogicalResult;
+class MLIRContext;
+class ModuleOp;
+class Operation;
+template <typename>
+class OwningOpRef;
+class Region;
+
+namespace transform {
+namespace detail {
+/// Non-templated part of TransformInterpreterPassBase::initialize.
+///
+/// Reads the transform script named by `transformFileName` (an empty name
+/// means the script is embedded in the payload IR) and, on success, stores
+/// the resulting module into `module` as a newly allocated shared pointer.
+/// Returns failure if reading the script fails.
+LogicalResult interpreterBaseInitializeImpl(
+    MLIRContext *context, StringRef transformFileName,
+    std::shared_ptr<OwningOpRef<ModuleOp>> &module);
+
+/// Non-templated part of TransformInterpreterPassBase::runOnOperation.
+///
+/// Selects the payload root and the transform entry point (honoring the
+/// debug tag options and the optional shared transform module), optionally
+/// dumps reproducers, and applies the transform script. `passName` and
+/// `binaryName` are only used to print reproducer command lines.
+LogicalResult interpreterBaseRunOnOperationImpl(
+    Operation *target, StringRef passName,
+    const std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
+    ArrayRef<ArrayRef<MappedValue>> extraMappings,
+    const TransformOptions &options,
+    const Pass::Option<std::string> &transformFileName,
+    const Pass::Option<std::string> &debugPayloadRootTag,
+    const Pass::Option<std::string> &debugTransformRootTag,
+    StringRef binaryName);
+} // namespace detail
+
+/// Base class for transform dialect interpreter passes that can consume and
+/// dump transform dialect scripts in separate files. The pass is controlled by
+/// three string options:
+///
+///   - transformFileName: if non-empty, the name of the file containing the
+///     transform script. If empty, `debugTransformRootTag` is considered or the
+///     pass root operation must contain a single top-level transform op that
+///     will be interpreted.
+///   - debugPayloadRootTag: if non-empty, the value of the attribute named
+///     `kTransformDialectTagAttrName` indicating the single op nested in the
+///     pass root operation that is considered the payload root of the
+///     transform interpreter; otherwise, the root operation of the pass is
+///     used.
+///   - debugTransformRootTag: if non-empty, the value of the attribute named
+///     `kTransformDialectTagAttrName` indicating the single top-level transform
+///     op to be used as the entry point by the transform interpreter. It is
+///     looked up in the module loaded from `transformFileName` when such a
+///     module was loaded, and in the pass root operation otherwise.
+///
+/// The pass runs the transform dialect interpreter as directed by the options.
+/// It also provides the mechanism to dump reproducers into the debug stream
+/// (-debug-only=transform-dialect-dump-repro) or into a temporary file
+/// (-debug-only=transform-dialect-save-repro) that can be used with this
+/// pass in a standalone mode.
+///
+/// Concrete passes must derive from this class instead of their generated base
+/// class, and supply themselves and that base class as template arguments.
+/// The base must be a class template taking exactly one type parameter (the
+/// concrete pass), so PassWrapper has to be adapted through a one-parameter
+/// wrapper. Concrete passes are *not* expected to implement `initialize` or
+/// `runOnOperation`. They *are* expected to call the copy constructor of this
+/// class in their copy constructors, short of which the file-based transform
+/// dialect script injection facility will become nonoperational.
+///
+/// Concrete passes may implement `runBeforeInterpreter` and
+/// `runAfterInterpreter` to customize the behavior of the pass.
+template <typename Concrete, template <typename> typename GeneratedBase>
+class TransformInterpreterPassBase : public GeneratedBase<Concrete> {
+public:
+  explicit TransformInterpreterPassBase(
+      const TransformOptions &options = TransformOptions())
+      : options(options) {}
+
+  /// Copies the interpreter options and shares (rather than re-parses) the
+  /// transform module of `pass`. The generated base is copy-constructed
+  /// explicitly instead of being default-initialized, which GCC's -Wextra
+  /// flags in user-written copy constructors.
+  TransformInterpreterPassBase(const TransformInterpreterPassBase &pass)
+      : GeneratedBase<Concrete>(pass), options(pass.options),
+        sharedTransformModule(pass.sharedTransformModule) {}
+
+  /// Name of the tool printed in reproducer command lines. It is looked up
+  /// through `Concrete`, so concrete passes may shadow it.
+  static StringLiteral getBinaryName() { return "mlir-opt"; }
+
+  /// Checks at compile time that the concrete pass declares the required
+  /// string options, then loads the transform file (if any) once so that all
+  /// copies of the pass can share it.
+  LogicalResult initialize(MLIRContext *context) override {
+
+#define REQUIRE_PASS_OPTION(NAME)                                              \
+  static_assert(                                                               \
+      std::is_same_v<                                                          \
+          std::remove_reference_t<decltype(std::declval<Concrete &>().NAME)>,  \
+          Pass::Option<std::string>>,                                          \
+      "required " #NAME " string pass option is missing")
+
+    REQUIRE_PASS_OPTION(transformFileName);
+    REQUIRE_PASS_OPTION(debugPayloadRootTag);
+    REQUIRE_PASS_OPTION(debugTransformRootTag);
+
+#undef REQUIRE_PASS_OPTION
+
+    StringRef transformFileName =
+        static_cast<Concrete *>(this)->transformFileName;
+    return detail::interpreterBaseInitializeImpl(context, transformFileName,
+                                                 sharedTransformModule);
+  }
+
+  /// Hook for passes to run additional logic in the pass before the
+  /// interpreter. If failure is returned, the pass fails and the interpreter is
+  /// not run.
+  LogicalResult runBeforeInterpreter(Operation *) { return success(); }
+
+  /// Hook for passes to run additional logic in the pass after the interpreter.
+  /// Only runs if everything succeeded before. If failure is returned, the pass
+  /// fails.
+  LogicalResult runAfterInterpreter(Operation *) { return success(); }
+
+  /// Runs the `runBeforeInterpreter` hook, the interpreter, and the
+  /// `runAfterInterpreter` hook in sequence, stopping and signaling pass
+  /// failure at the first step that fails. Hooks are resolved through
+  /// `Concrete`, so overrides in the concrete pass are picked up statically.
+  void runOnOperation() override {
+    auto *pass = static_cast<Concrete *>(this);
+    Operation *op = pass->getOperation();
+    StringRef binaryName = Concrete::getBinaryName();
+    if (failed(pass->runBeforeInterpreter(op)) ||
+        failed(detail::interpreterBaseRunOnOperationImpl(
+            op, pass->getArgument(), sharedTransformModule,
+            /*extraMappings=*/{}, options, pass->transformFileName,
+            pass->debugPayloadRootTag, pass->debugTransformRootTag,
+            binaryName)) ||
+        failed(pass->runAfterInterpreter(op))) {
+      return pass->signalPassFailure();
+    }
+  }
+
+protected:
+  /// Transform interpreter options.
+  TransformOptions options;
+
+  /// Returns a read-only reference to shared transform module.
+  const std::shared_ptr<OwningOpRef<ModuleOp>> &
+  getSharedTransformModule() const {
+    return sharedTransformModule;
+  }
+
+private:
+  /// The separate transform module to be used for transformations, shared
+  /// across multiple instances of the pass if it is applied in parallel to
+  /// avoid potentially expensive cloning. MUST NOT be modified after the pass
+  /// has been initialized.
+  std::shared_ptr<OwningOpRef<ModuleOp>> sharedTransformModule = nullptr;
+};
+
+} // namespace transform
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERPASSBASE_H
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 5ecc1f47573c6..5eb0eedb6d7af 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -7,9 +7,11 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Operation.h"
+#include "mlir/IR/OwningOpRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
@@ -828,6 +830,25 @@ transform::applyTransforms(Operation *payloadRoot,
}
#endif // NDEBUG
+ // If the transform dialect may use PDL which may modify the IR, clone it
+ // before use to avoid concurrent modification in case this is being called
+ // from pass instances running concurrently with a shared transform script.
+ auto *pdlDialect =
+ transform->getContext()->getLoadedDialect<pdl::PDLDialect>();
+ bool hasPDL = transform
+ .walk([pdlDialect](Operation *op) {
+ if (op->getDialect() == pdlDialect)
+ return WalkResult::interrupt();
+ return WalkResult::advance();
+ })
+ .wasInterrupted();
+
+ OwningOpRef<TransformOpInterface> owningCopy;
+ if (hasPDL) {
+ owningCopy = OwningOpRef<TransformOpInterface>(transform->clone());
+ transform = owningCopy.get();
+ }
+
TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
options);
return state.applyTransform(transform).checkAndReport();
diff --git a/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
index 68ad95bae3dbc..bf9a255bacad3 100644
--- a/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRTransformDialectTransforms
CheckUses.cpp
+ TransformInterpreterPassBase.cpp
DEPENDS
MLIRTransformDialectTransformsIncGen
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
new file mode 100644
index 0000000000000..d347c9d42d557
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -0,0 +1,349 @@
+//===- TransformInterpreterPassBase.cpp -----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Base class with shared implementation for transform dialect interpreter
+// passes.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/FileUtilities.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/Mutex.h"
+#include "llvm/Support/Path.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+#define DEBUG_TYPE "transform-dialect-interpreter"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+#define DEBUG_TYPE_DUMP_STDERR "transform-dialect-dump-repro"
+#define DEBUG_TYPE_DUMP_FILE "transform-dialect-save-repro"
+
+/// Attribute name under which operations carry a transform interpreter tag.
+static constexpr llvm::StringLiteral kTransformDialectTagAttrName =
+    "transform.target_tag";
+/// Tag value marking the operation that acts as the payload root.
+static constexpr llvm::StringLiteral kTransformDialectTagPayloadRootValue =
+    "payload_root";
+/// Tag value marking the top-level transform operation that serves as the
+/// interpreter entry point.
+static constexpr llvm::StringLiteral
+    kTransformDialectTagTransformContainerValue = "transform_container";
+
+/// Utility to parse the content of a `transformFileName` MLIR file containing
+/// a transform dialect specification into `transformModule`.
+///
+/// An empty file name is not an error: `transformModule` is left untouched
+/// and the transform script is expected to be embedded in the payload IR.
+/// Returns failure if the file cannot be opened (a diagnostic including the
+/// file system's reason is emitted) or if its content cannot be parsed (the
+/// parser emits its own diagnostics).
+static LogicalResult
+parseTransformModuleFromFile(MLIRContext *context,
+                             llvm::StringRef transformFileName,
+                             OwningOpRef<ModuleOp> &transformModule) {
+  if (transformFileName.empty()) {
+    LLVM_DEBUG(
+        DBGS() << "no transform file name specified, assuming the transform "
+                  "module is embedded in the IR next to the top-level\n");
+    return success();
+  }
+  // Parse transformFileName content into a ModuleOp.
+  std::string errorMessage;
+  auto memoryBuffer = mlir::openInputFile(transformFileName, &errorMessage);
+  if (!memoryBuffer) {
+    // Keep the established message prefix and append the reason reported by
+    // the file system so the user knows why the file could not be read.
+    return emitError(FileLineColLoc::get(
+               StringAttr::get(context, transformFileName), 0, 0))
+           << "failed to parse transform file: " << errorMessage;
+  }
+  // Tell sourceMgr about this buffer, the parser will pick it up.
+  llvm::SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
+  transformModule =
+      OwningOpRef<ModuleOp>(parseSourceFile<ModuleOp>(sourceMgr, context));
+  // A null module means parsing failed and diagnostics were already emitted.
+  // Without this check a malformed file would be silently ignored and the
+  // pass would fall back to looking for a transform script embedded in the
+  // payload IR.
+  if (!transformModule)
+    return failure();
+  return success();
+}
+
+/// Finds the single top-level transform operation nested in `root` (or `root`
+/// itself), i.e. a transform op that is not nested in another transform op.
+/// Emits an error and returns nullptr if more than one such operation exists
+/// (with a note pointing at the previously found one). Also emits an error
+/// and returns nullptr if none is found; a note then suggests using the
+/// `filenameOption` option to provide the transform as an external file.
+static Operation *findTopLevelTransform(Operation *root,
+                                        StringRef filenameOption) {
+  ::mlir::transform::TransformOpInterface topLevelTransform = nullptr;
+  WalkResult walkResult = root->walk<WalkOrder::PreOrder>(
+      [&](::mlir::transform::TransformOpInterface transformOp) {
+        if (!topLevelTransform) {
+          topLevelTransform = transformOp;
+          // Do not descend into the transform op: transform ops nested in it
+          // are not top-level.
+          return WalkResult::skip();
+        }
+        auto diag = transformOp.emitError()
+                    << "more than one top-level transform op";
+        diag.attachNote(topLevelTransform.getLoc())
+            << "previous top-level transform op";
+        return WalkResult::interrupt();
+      });
+  if (walkResult.wasInterrupted())
+    return nullptr;
+  if (!topLevelTransform) {
+    auto diag = root->emitError()
+                << "could not find a nested top-level transform op";
+    diag.attachNote() << "use the '" << filenameOption
+                      << "' option to provide transform as external file";
+    return nullptr;
+  }
+  return topLevelTransform;
+}
+
+/// Looks up the operation nested in `root` (or `root` itself) whose `tagKey`
+/// string attribute equals `tagValue`. Emits an error on `root` and returns
+/// nullptr when no such operation exists, or when several do; in the latter
+/// case, notes point at the first match and at the second one encountered.
+static Operation *findOpWithTag(Operation *root, StringRef tagKey,
+                                StringRef tagValue) {
+  Operation *match = nullptr;
+  WalkResult result = root->walk<WalkOrder::PreOrder>(
+      [&](Operation *candidate) -> WalkResult {
+        auto tag = candidate->getAttrOfType<StringAttr>(tagKey);
+        if (!tag || tag.getValue() != tagValue)
+          return WalkResult::advance();
+
+        if (!match) {
+          match = candidate;
+          return WalkResult::advance();
+        }
+
+        InFlightDiagnostic diag = root->emitError()
+                                  << "more than one operation with " << tagKey
+                                  << "=\"" << tagValue << "\" attribute";
+        diag.attachNote(match->getLoc()) << "first operation";
+        diag.attachNote(candidate->getLoc()) << "other operation";
+        return WalkResult::interrupt();
+      });
+  if (result.wasInterrupted())
+    return nullptr;
+
+  if (!match)
+    root->emitError() << "could not find the operation with " << tagKey
+                      << "=\"" << tagValue << "\" attribute";
+  return match;
+}
+
+/// Walks up the parent chain of `target` and returns the outermost operation,
+/// i.e. the first one without a parent op (possibly `target` itself).
+static Operation *getRootOperation(Operation *target) {
+  Operation *current = target;
+  while (Operation *parent = current->getParentOp())
+    current = parent;
+  return current;
+}
+
+/// Writes to `os` the command line that replays the reproducer: `binaryName`
+/// invoked with a pass pipeline that nests `passName` under `rootOpName` and
+/// sets both debug tag options. Empty tag options are replaced by the default
+/// tag values attached by the reproducer machinery. Returns `os`.
+// TODO: make binary name optional by querying LLVM command line API for the
+// name of the current binary.
+static llvm::raw_ostream &
+printReproCall(llvm::raw_ostream &os, StringRef rootOpName, StringRef passName,
+               const Pass::Option<std::string> &debugPayloadRootTag,
+               const Pass::Option<std::string> &debugTransformRootTag,
+               StringRef binaryName) {
+  StringRef payloadTagValue =
+      debugPayloadRootTag.empty()
+          ? StringRef(kTransformDialectTagPayloadRootValue)
+          : debugPayloadRootTag;
+  StringRef transformTagValue =
+      debugTransformRootTag.empty()
+          ? StringRef(kTransformDialectTagTransformContainerValue)
+          : debugTransformRootTag;
+  os << llvm::formatv("{6} --pass-pipeline=\"{0}({1}{{{2}={3} {4}={5}})\"",
+                      rootOpName, passName, debugPayloadRootTag.getArgStr(),
+                      payloadTagValue, debugTransformRootTag.getArgStr(),
+                      transformTagValue, binaryName);
+  return os;
+}
+
+/// Prints the IR rooted at `root` to `os`, followed by `transform` when the
+/// latter is not nested in `root` (e.g. when it comes from a separately
+/// loaded transform module), so that the output holds both the payload and
+/// the transform IR. Returns `os`. File-local like the other helpers here.
+static llvm::raw_ostream &printModuleForRepro(llvm::raw_ostream &os,
+                                              Operation *root,
+                                              Operation *transform) {
+  root->print(os);
+  if (!root->isAncestor(transform))
+    transform->print(os);
+  return os;
+}
+
+/// Saves the payload and the transform IR into a temporary file and reports
+/// the file name, together with the command line replaying it, to `os`.
+///
+/// Failures to create or preserve the temporary file are reported to `os`
+/// with the underlying reason. The corresponding llvm::Error values are
+/// consumed so that they do not trigger the "unchecked error" abort in builds
+/// with LLVM_ENABLE_ABI_BREAKING_CHECKS. Marked [[maybe_unused]] because the
+/// only caller sits inside DEBUG_WITH_TYPE, which compiles away in release
+/// builds.
+[[maybe_unused]] static void
+saveReproToTempFile(llvm::raw_ostream &os, Operation *target,
+                    Operation *transform, StringRef passName,
+                    const Pass::Option<std::string> &debugPayloadRootTag,
+                    const Pass::Option<std::string> &debugTransformRootTag,
+                    StringRef binaryName) {
+  using llvm::sys::fs::TempFile;
+  Operation *root = getRootOperation(target);
+
+  SmallVector<char, 128> tmpPath;
+  llvm::sys::path::system_temp_directory(/*erasedOnReboot=*/true, tmpPath);
+  llvm::sys::path::append(tmpPath, "transform_dialect_%%%%%%.mlir");
+  llvm::Expected<TempFile> tempFile = TempFile::create(tmpPath);
+  if (!tempFile) {
+    // toString consumes the error held by the Expected.
+    os << "could not open temporary file to save the repro: "
+       << llvm::toString(tempFile.takeError()) << "\n";
+    return;
+  }
+
+  llvm::raw_fd_ostream fout(tempFile->FD, /*shouldClose=*/false);
+  printModuleForRepro(fout, root, transform);
+  fout.flush();
+  std::string filename = tempFile->TmpName;
+
+  if (llvm::Error error = tempFile->keep()) {
+    os << "could not preserve the temporary file with the repro: "
+       << llvm::toString(std::move(error)) << "\n";
+    return;
+  }
+
+  os << "=== Transform Interpreter Repro ===\n";
+  printReproCall(os, root->getName().getStringRef(), passName,
+                 debugPayloadRootTag, debugTransformRootTag, binaryName)
+      << " " << filename << "\n";
+  os << "===================================\n";
+}
+
+/// Optionally performs debug actions requested by the user: dumps a
+/// reproducer (payload IR, transform IR and the command line replaying them)
+/// to the debug stream and/or saves it into a temporary file. Does nothing
+/// unless the `transform-dialect-dump-repro` or `transform-dialect-save-repro`
+/// debug types are enabled. In multithreaded mode it only prints a warning
+/// banner and returns.
+///
+/// The tag attributes referenced by the reproducer command line are attached
+/// to `target` and `transform` only for the duration of the dump, and are
+/// removed afterwards so they do not leak into the IR produced by the pass.
+static void performOptionalDebugActions(
+    Operation *target, Operation *transform, StringRef passName,
+    const Pass::Option<std::string> &debugPayloadRootTag,
+    const Pass::Option<std::string> &debugTransformRootTag,
+    StringRef binaryName) {
+  MLIRContext *context = target->getContext();
+
+  // If we are not planning to print, bail early.
+  bool hasDebugFlags = false;
+  DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_STDERR, { hasDebugFlags = true; });
+  DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_FILE, { hasDebugFlags = true; });
+  if (!hasDebugFlags)
+    return;
+
+  // We will be mutating the IR to set attributes. If this is running
+  // concurrently on several parts of a container or using a shared transform
+  // script, this would create a race. Bail in multithreaded mode and require
+  // the user to disable threading to dump repros.
+  static llvm::sys::SmartMutex<true> dbgStreamMutex;
+  if (context->isMultithreadingEnabled()) {
+    llvm::sys::SmartScopedLock<true> lock(dbgStreamMutex);
+    llvm::dbgs() << "=======================================================\n";
+    llvm::dbgs() << "| Transform reproducers cannot be produced |\n";
+    llvm::dbgs() << "| in multi-threaded mode! |\n";
+    llvm::dbgs() << "=======================================================\n";
+    return;
+  }
+
+  Operation *root = getRootOperation(target);
+  // `root` is only used inside DEBUG_WITH_TYPE, which expands to nothing in
+  // release builds.
+  (void)root;
+
+  // Add temporary debug / repro attributes; they are removed below and must
+  // never leak out. Attributes supplied by the user (non-empty tag options)
+  // are left alone.
+  bool addPayloadTag = debugPayloadRootTag.empty();
+  bool addTransformTag = debugTransformRootTag.empty();
+  if (addPayloadTag) {
+    target->setAttr(
+        kTransformDialectTagAttrName,
+        StringAttr::get(context, kTransformDialectTagPayloadRootValue));
+  }
+  if (addTransformTag) {
+    transform->setAttr(
+        kTransformDialectTagAttrName,
+        StringAttr::get(context, kTransformDialectTagTransformContainerValue));
+  }
+
+  DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_STDERR, {
+    llvm::dbgs() << "=== Transform Interpreter Repro ===\n";
+    printReproCall(llvm::dbgs() << "cat <<EOF | ",
+                   root->getName().getStringRef(), passName,
+                   debugPayloadRootTag, debugTransformRootTag, binaryName)
+        << "\n";
+    printModuleForRepro(llvm::dbgs(), root, transform);
+    llvm::dbgs() << "\nEOF\n";
+    llvm::dbgs() << "===================================\n";
+  });
+  DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_FILE, {
+    saveReproToTempFile(llvm::dbgs(), target, transform, passName,
+                        debugPayloadRootTag, debugTransformRootTag, binaryName);
+  });
+
+  // Remove the temporary attributes now that the reproducer has been emitted.
+  if (addPayloadTag)
+    target->removeAttr(kTransformDialectTagAttrName);
+  if (addTransformTag)
+    transform->removeAttr(kTransformDialectTagAttrName);
+}
+
+LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
+    Operation *target, StringRef passName,
+    const std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
+    ArrayRef<ArrayRef<MappedValue>> extraMappings,
+    const TransformOptions &options,
+    const Pass::Option<std::string> &transformFileName,
+    const Pass::Option<std::string> &debugPayloadRootTag,
+    const Pass::Option<std::string> &debugTransformRootTag,
+    StringRef binaryName) {
+  // Payload selection: the pass target by default, or the uniquely tagged op
+  // nested in it when a payload root tag is given (REPL-style debugging).
+  Operation *payloadRoot = target;
+  bool usePayloadTag = !debugPayloadRootTag.empty();
+  if (usePayloadTag) {
+    payloadRoot = findOpWithTag(target, kTransformDialectTagAttrName,
+                                debugPayloadRootTag);
+    if (!payloadRoot)
+      return failure();
+  }
+
+  // Transform script location: the separately loaded shared module when one
+  // is available, otherwise the payload IR itself.
+  Operation *transformContainer = target;
+  if (sharedTransformModule && *sharedTransformModule)
+    transformContainer = sharedTransformModule->get();
+
+  // Entry point selection: the tagged op when a transform root tag is given
+  // (REPL-style debugging), otherwise the single top-level transform op.
+  Operation *transformRoot = nullptr;
+  if (debugTransformRootTag.empty()) {
+    transformRoot = findTopLevelTransform(transformContainer,
+                                          transformFileName.getArgStr());
+  } else {
+    transformRoot = findOpWithTag(transformContainer,
+                                  kTransformDialectTagAttrName,
+                                  debugTransformRootTag);
+  }
+  if (!transformRoot)
+    return failure();
+
+  if (!transformRoot->hasTrait<PossibleTopLevelTransformOpTrait>())
+    return emitError(transformRoot->getLoc())
+           << "expected the transform entry point to be a top-level transform "
+              "op";
+
+  // Dump reproducers if the user asked for them; no-op otherwise.
+  performOptionalDebugActions(target, transformRoot, passName,
+                              debugPayloadRootTag, debugTransformRootTag,
+                              binaryName);
+
+  // Interpret the selected transform script on the selected payload.
+  auto entryPoint = cast<TransformOpInterface>(transformRoot);
+  return applyTransforms(payloadRoot, entryPoint, extraMappings, options);
+}
+
+LogicalResult transform::detail::interpreterBaseInitializeImpl(
+    MLIRContext *context, StringRef transformFileName,
+    std::shared_ptr<OwningOpRef<ModuleOp>> &module) {
+  OwningOpRef<ModuleOp> transformModule;
+  LogicalResult status =
+      parseTransformModuleFromFile(context, transformFileName, transformModule);
+  // Publish the (possibly null) module only on success; on failure `module`
+  // keeps whatever it held before.
+  if (succeeded(status))
+    module =
+        std::make_shared<OwningOpRef<ModuleOp>>(std::move(transformModule));
+  return status;
+}
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir
new file mode 100644
index 0000000000000..02b3babc10cc2
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir
@@ -0,0 +1,108 @@
+// RUN: mlir-opt -split-input-file --test-transform-dialect-interpreter --canonicalize --test-linalg-transform-patterns="test-generalize-tensor-unpack" %s | FileCheck %s
+
+func.func @KCRSsr_to_KCRS(%arg0: tensor<1x1x4x8x8x32xf32>, %arg1: tensor<1x1x128x64xf32>) -> tensor<1x1x128x64xf32> {
+ %0 = tensor.unpack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x4x8x8x32xf32> -> tensor<1x1x128x64xf32>
+ return %0 : tensor<1x1x128x64xf32>
+}
+
+transform.sequence failures(propagate) {
+ ^bb0(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ %1, %loops:4 = transform.structured.tile_to_scf_for %0 [1, 1, 32, 8]
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 32)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 floordiv 8)>
+// CHECK: func.func @KCRSsr_to_KCRS
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %{{.+}} = scf.for %[[R:[a-zA-Z0-9]+]] =
+// CHECK: %{{.+}} = scf.for %[[S:[a-zA-Z0-9]+]] =
+// CHECK: %[[IN_R:.+]] = affine.apply #[[MAP0]](%[[R]])
+// CHECK: %[[IN_S:.+]] = affine.apply #[[MAP1]](%[[S]])
+// CHECK: %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]]
+// CHECK-SAME: [0, 0, %[[IN_R]], %[[IN_S]], 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
+// CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1] : tensor<1x1x1x1x8x32xf32> to tensor<8x32xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
+// CHECK: %[[TRANSP:.+]] = linalg.transpose
+// CHECK-SAME: ins(%[[TILE]]
+// CHECK-SAME: outs(%[[EMPTY]]
+// CHECK-SAME: permutation = [1, 0]
+// CHECK: %{{.+}} = tensor.insert_slice %[[TRANSP]] into %{{.+}}
+
+// -----
+
+func.func @unpack_and_extract_slice(%arg0: tensor<2x8x8x2xf32>, %arg1: tensor<13x15xf32>) -> tensor<13x15xf32> {
+ %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %arg1 : tensor<2x8x8x2xf32> -> tensor<13x15xf32>
+ return %0 : tensor<13x15xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (-d0 + 13, 8)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (-d0 + 15, 2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 floordiv 8)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (d0 floordiv 2)>
+// CHECK: func.func @unpack_and_extract_slice
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %{{.+}} = scf.for %[[I:[a-zA-Z0-9]+]] =
+// CHECK: %[[OUT_I_SZ:.+]] = affine.min #[[MAP0]](%[[I]])
+// CHECK: %{{.+}} = scf.for %[[J:[a-zA-Z0-9]+]] =
+// CHECK: %[[OUT_J_SZ:.+]] = affine.min #[[MAP1]](%[[J]])
+// CHECK: %[[IN_I:.+]] = affine.apply #[[MAP2]](%[[I]])
+// CHECK: %[[IN_J:.+]] = affine.apply #[[MAP3]](%[[J]])
+// CHECK: %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]]
+// CHECK-SAME: [%[[IN_I]], %[[IN_J]], 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
+// CHECK: %[[ITER_SLICE:.+]] = tensor.extract_slice %{{[a-zA-Z0-9]+}}
+// CHECK-SAME: [%[[I]], %[[J]]] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]]
+// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
+// CHECK-SAME: [0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1] : tensor<1x1x8x2xf32> to tensor<8x2xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xf32>
+// CHECK: %[[TRANSP:.+]] = linalg.transpose
+// CHECK-SAME: ins(%[[TILE]] : tensor<8x2xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<8x2xf32>)
+// CHECK-SAME: permutation = [0, 1]
+// CHECK: %[[UNPACK_TILE:.+]] = tensor.extract_slice %[[TRANSP]]
+// CHECK-SAME: [0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1]
+// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[UNPACK_TILE]] into %[[ITER_SLICE]]
+// CHECK-SAME: [0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1]
+// CHECK: %[[INSERT2:.+]] = tensor.insert_slice %[[INSERT1]] into %{{[a-zA-Z0-9]+}}
+// CHECK-SAME: [%[[I]], %[[J]]] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1]
+
+transform.sequence failures(propagate) {
+ ^bb0(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ %1, %loops:2 = transform.structured.tile_to_scf_for %0 [8, 2]
+}
+
+// -----
+
+func.func @CKkc_to_KC(%arg0: tensor<32x4x32x8xf32>, %arg1: tensor<128x256xf32>) -> tensor<128x256xf32> {
+ %0 = tensor.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : tensor<32x4x32x8xf32> -> tensor<128x256xf32>
+ return %0 : tensor<128x256xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 32)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 floordiv 8)>
+// CHECK: func.func @CKkc_to_KC
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %{{.+}} = scf.for %[[K:[a-zA-Z0-9]+]] =
+// CHECK: %{{.+}} = scf.for %[[C:[a-zA-Z0-9]+]] =
+// CHECK: %[[IN_K:.+]] = affine.apply #[[MAP0]](%[[K]])
+// CHECK: %[[IN_C:.+]] = affine.apply #[[MAP1]](%[[C]])
+// CHECK: %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]]
+// CHECK-SAME: [%[[IN_C]], %[[IN_K]], 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
+// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
+// CHECK-SAME: [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<1x1x32x8xf32> to tensor<32x8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
+// CHECK: %[[TRANSP:.+]] = linalg.transpose
+// CHECK-SAME: ins(%[[TILE]]
+// CHECK-SAME: outs(%[[EMPTY]]
+// CHECK-SAME: permutation = [0, 1]
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %{{[a-zA-Z0-9]+}}
+// CHECK-SAME: [%[[K]], %[[C]]] [32, 8] [1, 1]
+
+
+transform.sequence failures(propagate) {
+ ^bb0(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ %1, %loops:2 = transform.structured.tile_to_scf_for %0 [32, 8]
+}
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
index a03161f9648fe..cc734c24d4f56 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
@@ -55,114 +55,3 @@ func.func @simple_CNnc_to_NC(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<32x8xf32
// They have the same type, so the insert_slice op is folded
// away.
// CHECK: return %[[TRANSP]]
-
-// -----
-
-// RUN: mlir-opt -split-input-file --test-transform-dialect-interpreter --canonicalize --test-linalg-transform-patterns="test-generalize-tensor-unpack" %s | FileCheck %s --check-prefix=CHECK-TRANS
-
-func.func @KCRSsr_to_KCRS(%arg0: tensor<1x1x4x8x8x32xf32>, %arg1: tensor<1x1x128x64xf32>) -> tensor<1x1x128x64xf32> {
- %0 = tensor.unpack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x4x8x8x32xf32> -> tensor<1x1x128x64xf32>
- return %0 : tensor<1x1x128x64xf32>
-}
-
-transform.sequence failures(propagate) {
- ^bb0(%arg1: !pdl.operation):
- %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1 : (!pdl.operation) -> !pdl.operation
- %1, %loops:4 = transform.structured.tile_to_scf_for %0 [1, 1, 32, 8]
-}
-// CHECK-TRANS-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 32)>
-// CHECK-TRANS-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 floordiv 8)>
-// CHECK-TRANS: func.func @KCRSsr_to_KCRS
-// CHECK-TRANS-SAME: %[[SRC:[a-zA-Z0-9]+]]
-// CHECK-TRANS-SAME: %[[DEST:[a-zA-Z0-9]+]]
-// CHECK-TRANS: %{{.+}} = scf.for %[[R:[a-zA-Z0-9]+]] =
-// CHECK-TRANS: %{{.+}} = scf.for %[[S:[a-zA-Z0-9]+]] =
-// CHECK-TRANS: %[[IN_R:.+]] = affine.apply #[[MAP0]](%[[R]])
-// CHECK-TRANS: %[[IN_S:.+]] = affine.apply #[[MAP1]](%[[S]])
-// CHECK-TRANS: %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]]
-// CHECK-TRANS-SAME: [0, 0, %[[IN_R]], %[[IN_S]], 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
-// CHECK-TRANS: %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
-// CHECK-TRANS-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1] : tensor<1x1x1x1x8x32xf32> to tensor<8x32xf32>
-// CHECK-TRANS: %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
-// CHECK-TRANS: %[[TRANSP:.+]] = linalg.transpose
-// CHECK-TRANS-SAME: ins(%[[TILE]]
-// CHECK-TRANS-SAME: outs(%[[EMPTY]]
-// CHECK-TRANS-SAME: permutation = [1, 0]
-// CHECK-TRANS: %{{.+}} = tensor.insert_slice %[[TRANSP]] into %{{.+}}
-
-// -----
-
-func.func @unpack_and_extract_slice(%arg0: tensor<2x8x8x2xf32>, %arg1: tensor<13x15xf32>) -> tensor<13x15xf32> {
- %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %arg1 : tensor<2x8x8x2xf32> -> tensor<13x15xf32>
- return %0 : tensor<13x15xf32>
-}
-// CHECK-TRANS-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (-d0 + 13, 8)>
-// CHECK-TRANS-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (-d0 + 15, 2)>
-// CHECK-TRANS-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 floordiv 8)>
-// CHECK-TRANS-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (d0 floordiv 2)>
-// CHECK-TRANS: func.func @unpack_and_extract_slice
-// CHECK-TRANS-SAME: %[[SRC:[a-zA-Z0-9]+]]
-// CHECK-TRANS-SAME: %[[DEST:[a-zA-Z0-9]+]]
-// CHECK-TRANS: %{{.+}} = scf.for %[[I:[a-zA-Z0-9]+]] =
-// CHECK-TRANS: %[[OUT_I_SZ:.+]] = affine.min #[[MAP0]](%[[I]])
-// CHECK-TRANS: %{{.+}} = scf.for %[[J:[a-zA-Z0-9]+]] =
-// CHECK-TRANS: %[[OUT_J_SZ:.+]] = affine.min #[[MAP1]](%[[J]])
-// CHECK-TRANS: %[[IN_I:.+]] = affine.apply #[[MAP2]](%[[I]])
-// CHECK-TRANS: %[[IN_J:.+]] = affine.apply #[[MAP3]](%[[J]])
-// CHECK-TRANS: %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]]
-// CHECK-TRANS-SAME: [%[[IN_I]], %[[IN_J]], 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
-// CHECK-TRANS: %[[ITER_SLICE:.+]] = tensor.extract_slice %{{[a-zA-Z0-9]+}}
-// CHECK-TRANS-SAME: [%[[I]], %[[J]]] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]]
-// CHECK-TRANS: %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
-// CHECK-TRANS-SAME: [0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1] : tensor<1x1x8x2xf32> to tensor<8x2xf32>
-// CHECK-TRANS: %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xf32>
-// CHECK-TRANS: %[[TRANSP:.+]] = linalg.transpose
-// CHECK-TRANS-SAME: ins(%[[TILE]] : tensor<8x2xf32>)
-// CHECK-TRANS-SAME: outs(%[[EMPTY]] : tensor<8x2xf32>)
-// CHECK-TRANS-SAME: permutation = [0, 1]
-// CHECK-TRANS: %[[UNPACK_TILE:.+]] = tensor.extract_slice %[[TRANSP]]
-// CHECK-TRANS-SAME: [0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1]
-// CHECK-TRANS: %[[INSERT1:.+]] = tensor.insert_slice %[[UNPACK_TILE]] into %[[ITER_SLICE]]
-// CHECK-TRANS-SAME: [0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1]
-// CHECK-TRANS: %[[INSERT2:.+]] = tensor.insert_slice %[[INSERT1]] into %{{[a-zA-Z0-9]+}}
-// CHECK-TRANS-SAME: [%[[I]], %[[J]]] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1]
-
-transform.sequence failures(propagate) {
- ^bb0(%arg1: !pdl.operation):
- %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1 : (!pdl.operation) -> !pdl.operation
- %1, %loops:2 = transform.structured.tile_to_scf_for %0 [8, 2]
-}
-
-// -----
-
-func.func @CKkc_to_KC(%arg0: tensor<32x4x32x8xf32>, %arg1: tensor<128x256xf32>) -> tensor<128x256xf32> {
- %0 = tensor.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : tensor<32x4x32x8xf32> -> tensor<128x256xf32>
- return %0 : tensor<128x256xf32>
-}
-// CHECK-TRANS-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 32)>
-// CHECK-TRANS-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 floordiv 8)>
-// CHECK-TRANS: func.func @CKkc_to_KC
-// CHECK-TRANS-SAME: %[[SRC:[a-zA-Z0-9]+]]
-// CHECK-TRANS-SAME: %[[DEST:[a-zA-Z0-9]+]]
-// CHECK-TRANS: %{{.+}} = scf.for %[[K:[a-zA-Z0-9]+]] =
-// CHECK-TRANS: %{{.+}} = scf.for %[[C:[a-zA-Z0-9]+]] =
-// CHECK-TRANS: %[[IN_K:.+]] = affine.apply #[[MAP0]](%[[K]])
-// CHECK-TRANS: %[[IN_C:.+]] = affine.apply #[[MAP1]](%[[C]])
-// CHECK-TRANS: %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]]
-// CHECK-TRANS-SAME: [%[[IN_C]], %[[IN_K]], 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
-// CHECK-TRANS: %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
-// CHECK-TRANS-SAME: [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<1x1x32x8xf32> to tensor<32x8xf32>
-// CHECK-TRANS: %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
-// CHECK-TRANS: %[[TRANSP:.+]] = linalg.transpose
-// CHECK-TRANS-SAME: ins(%[[TILE]]
-// CHECK-TRANS-SAME: outs(%[[EMPTY]]
-// CHECK-TRANS-SAME: permutation = [0, 1]
-// CHECK-TRANS: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %{{[a-zA-Z0-9]+}}
-// CHECK-TRANS-SAME: [%[[K]], %[[C]]] [32, 8] [1, 1]
-
-
-transform.sequence failures(propagate) {
- ^bb0(%arg1: !pdl.operation):
- %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1 : (!pdl.operation) -> !pdl.operation
- %1, %loops:2 = transform.structured.tile_to_scf_for %0 [32, 8]
-}
diff --git a/mlir/test/Dialect/Linalg/generalize-tesnor-pack-tile.mlir b/mlir/test/Dialect/Linalg/generalize-tesnor-pack-tile.mlir
new file mode 100644
index 0000000000000..09ebc45ccb57b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/generalize-tesnor-pack-tile.mlir
@@ -0,0 +1,103 @@
+// RUN: mlir-opt -split-input-file --test-transform-dialect-interpreter --canonicalize --test-linalg-transform-patterns="test-generalize-tensor-pack" %s | FileCheck %s
+
+func.func @KCRS_to_KCRSsr(%arg0: tensor<1x1x128x64xf32>, %arg1: tensor<1x1x4x8x8x32xf32>) -> tensor<1x1x4x8x8x32xf32> {
+ %0 = tensor.pack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x128x64xf32> -> tensor<1x1x4x8x8x32xf32>
+ return %0 : tensor<1x1x4x8x8x32xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 32)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 * -32 + 128, 32)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 * 8)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (d0 * -8 + 64, 8)>
+// CHECK: func.func @KCRS_to_KCRSsr
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %{{.+}} = scf.for %[[R:[a-zA-Z0-9]+]] =
+// CHECK: %{{.+}} = scf.for %[[S:[a-zA-Z0-9]+]] =
+// CHECK: %[[IN_R:.+]] = affine.apply #[[MAP0]](%[[R]])
+// CHECK: %[[IN_R_SZ:.+]] = affine.min #[[MAP1]](%[[R]])
+// CHECK: %[[IN_S:.+]] = affine.apply #[[MAP2]](%[[S]])
+// CHECK: %[[IN_S_SZ:.+]] = affine.min #[[MAP3]](%[[S]])
+// CHECK: %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]]
+// CHECK-SAME: [0, 0, %[[IN_R]], %[[IN_S]]] [1, 1, %[[IN_R_SZ]], %[[IN_S_SZ]]] [1, 1, 1, 1]
+// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
+// CHECK-SAME: [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<1x1x?x?xf32> to tensor<32x8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x32xf32>
+// CHECK: %[[TRANSP:.+]] = linalg.transpose
+// CHECK-SAME: ins(%[[TILE]]
+// CHECK-SAME: outs(%[[EMPTY]]
+// CHECK-SAME: permutation = [1, 0]
+// CHECK: %{{.+}} = tensor.insert_slice %[[TRANSP]] into %{{.+}}
+
+transform.sequence failures(propagate) {
+ ^bb0(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["tensor.pack"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ %1, %loops:4 = transform.structured.tile_to_scf_for %0 [1, 1, 1, 1]
+}
+
+// -----
+
+func.func @pad_and_pack(%arg0: tensor<13x15xf32>, %arg1: tensor<2x8x8x2xf32>, %arg2: f32) -> tensor<2x8x8x2xf32> {
+ %0 = tensor.pack %arg0 padding_value(%arg2 : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %arg1 : tensor<13x15xf32> -> tensor<2x8x8x2xf32>
+ return %0 : tensor<2x8x8x2xf32>
+}
+// CHECK: func.func @pad_and_pack
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[PAD_VAL:[a-zA-Z0-9]+]]
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]]
+// CHECK: %[[PAD:.+]] = tensor.pad %[[SRC_SLICE]]
+// CHECK: tensor.yield %[[PAD_VAL]]
+// CHECK: } : tensor<?x?xf32> to tensor<8x2xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xf32>
+// CHECK: %[[TRANSP:.+]] = linalg.transpose
+// CHECK-SAME: ins(%[[PAD]] : tensor<8x2xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<8x2xf32>)
+// CHECK-SAME: permutation = [0, 1]
+// CHECK: %{{.+}} = tensor.insert_slice %[[TRANSP]] into %{{.+}}
+
+transform.sequence failures(propagate) {
+ ^bb0(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["tensor.pack"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ %1, %loops:2 = transform.structured.tile_to_scf_for %0 [1, 1]
+}
+
+// -----
+
+
+func.func @KC_to_CKkc(%arg0: tensor<128x256xf32>, %arg1: tensor<32x4x32x8xf32>) -> tensor<32x4x32x8xf32> {
+ %0 = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : tensor<128x256xf32> -> tensor<32x4x32x8xf32>
+ return %0 : tensor<32x4x32x8xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 32)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 * -32 + 128, 32)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 * 8)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (d0 * -8 + 256, 8)>
+// CHECK: func.func @KC_to_CKkc
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %{{.+}} = scf.for %[[C:[a-zA-Z0-9]+]] =
+// CHECK: %{{.+}} = scf.for %[[K:[a-zA-Z0-9]+]] =
+// CHECK-DAG: %[[IN_K:.+]] = affine.apply #[[MAP0]](%[[K]])
+// CHECK-DAG: %[[IN_K_SZ:.+]] = affine.min #[[MAP1]](%[[K]])
+// CHECK-DAG: %[[IN_C:.+]] = affine.apply #[[MAP2]](%[[C]])
+// CHECK-DAG: %[[IN_C_SZ:.+]] = affine.min #[[MAP3]](%[[C]])
+// CHECK: %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]]
+// CHECK-SAME: [%[[IN_K]], %[[IN_C]]] [%[[IN_K_SZ]], %[[IN_C_SZ]]] [1, 1]
+// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
+// CHECK-SAME: [0, 0] [32, 8] [1, 1] : tensor<?x?xf32> to tensor<32x8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
+// CHECK: %[[TRANSP:.+]] = linalg.transpose
+// CHECK-SAME: ins(%[[TILE]]
+// CHECK-SAME: outs(%[[EMPTY]]
+// CHECK-SAME: permutation = [0, 1]
+// CHECK: %[[SUB_ITER:.+]] = tensor.insert_slice %[[TRANSP]] into %{{[a-zA-Z0-9]+}}
+// CHECK-SAME: [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<32x8xf32> into tensor<1x1x32x8xf32>
+// CHECK: %{{.+}} = tensor.insert_slice %[[SUB_ITER]] into %{{[a-zA-Z0-9]+}}
+// CHECK-SAME: [%[[C]], %[[K]], 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<1x1x32x8xf32> into tensor<32x4x32x8xf32>
+transform.sequence failures(propagate) {
+ ^bb0(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["tensor.pack"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ %1, %loops:2 = transform.structured.tile_to_scf_for %0 [1, 1]
+}
diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index e2063e181480b..6214e4e527f14 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -1,7 +1,5 @@
// RUN: mlir-opt %s -test-transform-dialect-interpreter -test-linalg-transform-patterns=test-patterns -split-input-file | FileCheck %s
-// -----
-
func.func @dot(%x: memref<?xf32, strided<[1], offset: ?>>,
%y: memref<?xf32, strided<[1], offset: ?>>,
%v: memref<f32>) {
diff --git a/mlir/test/Dialect/Transform/test-interpreter-debug.mlir b/mlir/test/Dialect/Transform/test-interpreter-debug.mlir
new file mode 100644
index 0000000000000..efb9b8429b19c
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-debug.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{debug-payload-root-tag=payload debug-transform-root-tag=transform})" \
+// RUN: --allow-unregistered-dialect --split-input-file --verify-diagnostics
+
+// expected-error @below {{could not find the operation with transform.target_tag="payload" attribute}}
+module {
+ transform.sequence failures(suppress) {
+ ^bb0(%arg0: !transform.any_op):
+ }
+}
+
+// -----
+
+// expected-error @below {{could not find the operation with transform.target_tag="transform" attribute}}
+module {
+ transform.sequence failures(suppress) {
+ ^bb0(%arg0: !transform.any_op):
+ }
+
+ module attributes {transform.target_tag="payload"} {}
+}
+
+// -----
+
+// expected-error @below {{more than one operation with transform.target_tag="transform" attribute}}
+module {
+ // expected-note @below {{first operation}}
+ transform.sequence failures(propagate) attributes {transform.target_tag="transform"} {
+ ^bb0(%arg0: !transform.any_op):
+ }
+
+ // expected-note @below {{other operation}}
+ transform.sequence failures(propagate) attributes {transform.target_tag="transform"} {
+ ^bb0(%arg0: !transform.any_op):
+ }
+
+ module attributes {transform.target_tag="payload"} {}
+}
+
+// -----
+
+module {
+ // expected-error @below {{expected the transform entry point to be a top-level transform op}}
+ func.func private @foo() attributes {transform.target_tag="transform"}
+
+ module attributes {transform.target_tag="payload"} {}
+}
+
+// -----
+
+module {
+ transform.sequence failures(suppress) attributes {transform.target_tag="transform"} {
+ ^bb0(%arg0: !transform.any_op):
+ transform.test_print_remark_at_operand %arg0, "payload" : !transform.any_op
+ }
+
+ // This will not be executed because it's not tagged.
+ transform.sequence failures(suppress) {
+ ^bb0(%arg0: !transform.any_op):
+ transform.test_print_remark_at_operand %arg0, "some other text that is not printed" : !transform.any_op
+ }
+
+ module {
+ module {}
+ // expected-remark @below {{payload}}
+ module attributes {transform.target_tag="payload"} {}
+ module {}
+ }
+}
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-source.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-source.mlir
new file mode 100644
index 0000000000000..c3e5bff898de5
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-source.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt %s
+// No need to check anything else than parsing here, this is being used by another test as data.
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+ transform.test_print_remark_at_operand %arg0, "outer" : !transform.any_op
+ transform.sequence %arg0 : !transform.any_op failures(propagate) attributes {transform.target_tag="transform"} {
+ ^bb1(%arg1: !transform.any_op):
+ transform.test_print_remark_at_operand %arg1, "inner" : !transform.any_op
+ }
+}
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external.mlir b/mlir/test/Dialect/Transform/test-interpreter-external.mlir
new file mode 100644
index 0000000000000..5ac6b66c817af
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-external.mlir
@@ -0,0 +1,8 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-source.mlir})" \
+// RUN: --verify-diagnostics
+
+// The schedule in the separate file emits remarks at the payload root.
+
+// expected-remark @below {{outer}}
+// expected-remark @below {{inner}}
+module {}
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index b0b16abc0399c..f470606c78cf5 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1091,3 +1091,22 @@ transform.sequence failures(propagate) {
// expected-error @below {{attempting to assign a null parameter to this transform value}}
%0 = transform.test_produce_null_param : !transform.param<i64>
}
+
+// -----
+
+// expected-error @below {{could not find a nested top-level transform op}}
+// expected-note @below {{use the 'transform-file-name' option to provide transform as external file}}
+module {
+}
+
+// -----
+
+// expected-note @below {{previous top-level transform op}}
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+}
+
+// expected-error @below {{more than one top-level transform op}}
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+}
diff --git a/mlir/test/lib/Dialect/Transform/CMakeLists.txt b/mlir/test/lib/Dialect/Transform/CMakeLists.txt
index 0119fb07d1c68..56c7031481b95 100644
--- a/mlir/test/lib/Dialect/Transform/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Transform/CMakeLists.txt
@@ -19,4 +19,5 @@ add_mlir_library(MLIRTestTransformDialect
MLIRPass
MLIRPDLDialect
MLIRTransformDialect
+ MLIRTransformDialectTransforms
)
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
index 7d049eb98be51..afc7ed739fdc0 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
@@ -21,16 +22,22 @@ using namespace mlir;
namespace {
/// Simple pass that applies transform dialect ops directly contained in a
/// module.
+
+/// Binds PassWrapper to an OperationPass<ModuleOp> anchor so that it can be
+/// supplied as a single-parameter template (taking only the derived pass
+/// type) to transform::TransformInterpreterPassBase.
+template <typename Derived>
+class ModulePassWrapper : public PassWrapper<Derived, OperationPass<ModuleOp>> {
+};
+
class TestTransformDialectInterpreterPass
- : public PassWrapper<TestTransformDialectInterpreterPass,
- OperationPass<ModuleOp>> {
+ : public transform::TransformInterpreterPassBase<
+ TestTransformDialectInterpreterPass, ModulePassWrapper> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestTransformDialectInterpreterPass)
TestTransformDialectInterpreterPass() = default;
TestTransformDialectInterpreterPass(
- const TestTransformDialectInterpreterPass &) {}
+ const TestTransformDialectInterpreterPass &pass)
+ : TransformInterpreterPassBase(pass) {}
StringRef getArgument() const override {
return "test-transform-dialect-interpreter";
@@ -101,15 +108,12 @@ class TestTransformDialectInterpreterPass
getContext(), bindSecondExtraToParams, extraMappingStorage));
}
- ModuleOp module = getOperation();
- for (auto op :
- module.getBody()->getOps<transform::TransformOpInterface>()) {
- if (failed(transform::applyTransforms(
- module, op, extraMapping,
- transform::TransformOptions().enableExpensiveChecks(
- enableExpensiveChecks))))
- return signalPassFailure();
- }
+ options = options.enableExpensiveChecks(enableExpensiveChecks);
+ if (failed(transform::detail::interpreterBaseRunOnOperationImpl(
+ getOperation(), getArgument(), getSharedTransformModule(),
+ extraMapping, options, transformFileName, debugPayloadRootTag,
+ debugTransformRootTag, getBinaryName())))
+ return signalPassFailure();
}
Option<bool> enableExpensiveChecks{
@@ -134,6 +138,25 @@ class TestTransformDialectInterpreterPass
*this, "bind-second-extra-to-params",
llvm::cl::desc("bind the second extra argument of the top-level op to "
"the given integer parameters")};
+ Option<std::string> transformFileName{
+ *this, "transform-file-name", llvm::cl::init(""),
+ llvm::cl::desc(
+ "Optional filename containing a transform dialect specification to "
+ "apply. If left empty, the IR is assumed to contain one top-level "
+ "transform dialect operation somewhere in the module.")};
+ Option<std::string> debugPayloadRootTag{
+ *this, "debug-payload-root-tag", llvm::cl::init(""),
+ llvm::cl::desc(
+ "Select the operation with 'transform.target_tag' attribute having "
+ "the given value as payload IR root. If empty select the pass anchor "
+ "operation as the payload IR root.")};
+ Option<std::string> debugTransformRootTag{
+ *this, "debug-transform-root-tag", llvm::cl::init(""),
+ llvm::cl::desc(
+ "Select the operation with 'transform.target_tag' attribute having "
+ "the given value as container IR for top-level transform ops. This "
+ "allows user control on what transformation to apply. If empty, "
+ "select the container of the top-level transform op.")};
};
struct TestTransformDialectEraseSchedulePass
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 1534452a10eea..fc9a057111b04 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -9204,8 +9204,10 @@ cc_library(
deps = [
":Analysis",
":IR",
+ ":Parser",
":Pass",
":SideEffectInterfaces",
+ ":Support",
":TransformDialect",
":TransformDialectTransformsIncGen",
"//llvm:Support",
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index a3e96b5c633b7..01524c08c8ccf 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -319,6 +319,7 @@ cc_library(
"//mlir:PDLDialect",
"//mlir:Pass",
"//mlir:TransformDialect",
+ "//mlir:TransformDialectTransforms",
],
)
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/Dialect/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/Dialect/BUILD.bazel
index 6e997e9f47e77..adfb3d0360203 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/Dialect/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/Dialect/BUILD.bazel
@@ -13,7 +13,12 @@ package(default_visibility = ["//visibility:public"])
"//mlir:mlir-opt",
"//mlir:mlir-translate",
"//mlir/test:lit_data",
- ],
+ ] + glob([
+ "Transform/*-source.mlir",
+ ])
+ )
+ for src in glob(
+ include=["**/*.mlir"],
+ exclude=["Transform/*-source.mlir"]
)
- for src in glob(["**/*.mlir"])
]
More information about the Mlir-commits
mailing list