[llvm] [mlir][Transform] Create a transform interpreter and a preloader pass (PR #68661)

Nicolas Vasilache via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 10 11:36:37 PDT 2023


https://github.com/nicolasvasilache updated https://github.com/llvm/llvm-project/pull/68661

>From c73a71e846bb304bbf48b3fb74d2c2d24c09369e Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolas.vasilache at gmail.com>
Date: Fri, 6 Oct 2023 09:53:18 +0000
Subject: [PATCH 1/5] [mlir][Transform] Create a transform interpreter and a
 preloader pass

This revision provides the ability to use an arbitrary named sequence op as
the entry point to a transform dialect strategy.

It is also a step towards better transform dialect usage in pass pipelines
that need to preload a transform library rather than parse it on the fly.

The interpreter itself is significantly simpler than its testing counterpart
by avoiding payload/debug root tags and multiple shared modules.

In the process, the NamedSequenceOp::apply function is adapted to allow it
to be used as an entry point.

NamedSequenceOp is **not** extended to take the PossibleTopLevelTrait at this
time, because the implementation of the trait is specific to allowing one
top-level dangling op with a region such as SequenceOp or AlternativesOp.
In particular, the verifier of PossibleTopLevelTrait does not allow for an
empty body, which is necessary to declare a NamedSequenceOp that gets linked
in separately before application.

In the future, we should dispense with the PossibleTopLevelTrait altogether
and always enter the interpreter with a NamedSequenceOp.

Lastly, relevant TD linking utilities are moved to TransformInterpreterUtils
and reused from there.
---
 .../Dialect/Transform/Transforms/Passes.td    |  31 +++++
 .../Transforms/TransformInterpreterUtils.h    |  18 +++
 .../lib/Dialect/Transform/IR/TransformOps.cpp |  17 ++-
 .../Transform/Transforms/CMakeLists.txt       |   2 +
 .../Transform/Transforms/InterpreterPass.cpp  |  55 ++++++++
 .../Transforms/PreloadLibraryPass.cpp         |  41 ++++++
 .../TransformInterpreterPassBase.cpp          |  53 --------
 .../Transforms/TransformInterpreterUtils.cpp  | 125 ++++++++++++++++--
 mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir |  19 ++-
 .../Library/lower-to-llvm.mlir}               |   0
 .../mlir/test/Dialect/BUILD.bazel             |   2 +-
 11 files changed, 286 insertions(+), 77 deletions(-)
 create mode 100644 mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp
 create mode 100644 mlir/lib/Dialect/Transform/Transforms/PreloadLibraryPass.cpp
 rename mlir/test/Dialect/{LLVM/lower-to-llvm-transform-symbol-def.mlir => Transform/Library/lower-to-llvm.mlir} (100%)

diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/Passes.td b/mlir/include/mlir/Dialect/Transform/Transforms/Passes.td
index 2400066c8ad8c8b..45a26496335850a 100644
--- a/mlir/include/mlir/Dialect/Transform/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/Passes.td
@@ -42,4 +42,35 @@ def InferEffectsPass : Pass<"transform-infer-effects"> {
   }];
 }
 
+def PreloadLibraryPass : Pass<"transform-preload-library"> {
+  let summary = "preload transform dialect library";
+  let description = [{
+    This pass preloads a transform library and makes it available to subsequent
+    transform interpreter passes. The preloading occurs into the Transform
+    dialect and thus provides very limited functionality that does not scale.
+
+    Warning: Only a single such pass should exist for a given MLIR context.
+    This is a temporary solution until a resource-based solution is available.
+    TODO: use a resource blob.
+  }];
+  let options = [
+    ListOption<"transformLibraryPaths", "transform-library-paths", "std::string",
+    "Optional paths to files with modules that should be merged into the "
+    "transform module to provide the definitions of external named sequences.">
+  ];
+}
+
+def InterpreterPass : Pass<"transform-interpreter"> {
+  let summary = "transform dialect interpreter";
+  let description = [{
+    This pass runs the transform dialect interpreter and applies the named
+    sequence transformation specified by the provided name (defaults to
+    `__transform_main`).
+  }];
+  let options = [
+    Option<"entryPoint", "entry-point", "std::string",
+           /*default=*/[{"__transform_main"}],
+           "Name of the named sequence to use as the interpreter entry point.">,
+  ];
+}
 #endif // MLIR_DIALECT_TRANSFORM_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
index 36c80e6fd61d3c1..3fc02267f26e9da 100644
--- a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
@@ -26,6 +26,16 @@ class Region;
 
 namespace transform {
 namespace detail {
+
+/// Expands the given list of `paths` to a list of `.mlir` files.
+///
+/// A regular file in `paths` is appended to `fileNames` as is (any extension);
+/// a directory contributes the regular `.mlir` files directly inside it (no
+/// recursion). Any other path kind, or a listing error, yields a failure.
+LogicalResult expandPathsToMLIRFiles(ArrayRef<std::string> &paths,
+                                     MLIRContext *context,
+                                     SmallVectorImpl<std::string> &fileNames);
+
 /// Utility to parse and verify the content of a `transformFileName` MLIR file
 /// containing a transform dialect specification.
 LogicalResult
@@ -33,6 +43,14 @@ parseTransformModuleFromFile(MLIRContext *context,
                              llvm::StringRef transformFileName,
                              OwningOpRef<ModuleOp> &transformModule);
 
+/// Utility to parse and verify every `.mlir` file obtained by expanding
+/// `transformLibraryPaths` (see `expandPathsToMLIRFiles`) and to merge them
+/// into a single new module that is returned in `transformModule`.
+LogicalResult
+assembleTransformLibraryFromPaths(MLIRContext *context,
+                                  ArrayRef<std::string> transformLibraryPaths,
+                                  OwningOpRef<ModuleOp> &transformModule);
+
 /// Utility to load a transform interpreter `module` from a module that has
 /// already been preloaded in the context.
 /// This mode is useful in cases where explicit parsing of a transform library
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 0e20b379cc2a3e7..660879c29728022 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1761,8 +1761,23 @@ DiagnosedSilenceableFailure
 transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter,
                                   transform::TransformResults &results,
                                   transform::TransformState &state) {
-  // Nothing to do here.
-  return DiagnosedSilenceableFailure::success();
+  if (getBody().empty()) {
+    emitOpError("Cannot apply a bodyless named sequence, did you forget to link?");
+    return DiagnosedSilenceableFailure::definiteFailure();
+  }
+  // Map the entry block argument to the list of operations.
+  // Note: this is the same implementation as PossibleTopLevelTransformOp but
+  // without attaching the interface / trait since that is tailored to a
+  // dangling top-level op that does not get "called".
+  // NOTE(review): `scope` is presumably an RAII guard for the body's handle
+  // mappings; being a named local, it lives until this function returns.
+  auto scope = state.make_region_scope(getBody());
+  if (failed(detail::mapPossibleTopLevelTransformOpBlockArguments(
+      state, this->getOperation(), getBody())))
+    return DiagnosedSilenceableFailure::definiteFailure();
+
+  return applySequenceBlock(getBody().front(), FailurePropagationMode::Propagate, state,
+                            results);
 }
 
 void transform::NamedSequenceOp::getEffects(
diff --git a/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
index 8774a8b86fb0d91..f0f57874f5e7032 100644
--- a/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
@@ -1,6 +1,8 @@
 add_mlir_dialect_library(MLIRTransformDialectTransforms
   CheckUses.cpp
   InferEffects.cpp
+  InterpreterPass.cpp
+  PreloadLibraryPass.cpp
   TransformInterpreterPassBase.cpp
   TransformInterpreterUtils.cpp
 
diff --git a/mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp b/mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp
new file mode 100644
index 000000000000000..4a6921b8611bf46
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp
@@ -0,0 +1,65 @@
+//===- InterpreterPass.cpp - Transform dialect interpreter pass -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Transforms/Passes.h"
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
+
+using namespace mlir;
+
+namespace mlir {
+namespace transform {
+#define GEN_PASS_DEF_INTERPRETERPASS
+#include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
+} // namespace transform
+} // namespace mlir
+
+namespace {
+/// Pass applying the transform named sequence `entryPoint` to the payload
+/// operation, using the transform library preloaded in the context (if any).
+class InterpreterPass
+    : public transform::impl::InterpreterPassBase<InterpreterPass> {
+public:
+  using Base::Base;
+
+  /// Snapshots the preloaded transform module, if one exists, into a clone
+  /// shared by all copies of this pass.
+  LogicalResult initialize(MLIRContext *context) override {
+    // TODO: use a resource blob.
+    ModuleOp transformModule =
+        transform::detail::getPreloadedTransformModule(context);
+    if (transformModule) {
+      sharedTransformModule =
+          std::make_shared<OwningOpRef<ModuleOp>>(transformModule.clone());
+    }
+    return success();
+  }
+
+  void runOnOperation() override {
+    // Nothing may have been preloaded: pass a null module rather than
+    // dereferencing an empty `sharedTransformModule`; the entry point can
+    // then still be looked up in the payload itself.
+    ModuleOp transformModule =
+        sharedTransformModule ? sharedTransformModule->get() : ModuleOp();
+    if (failed(transform::applyTransformNamedSequence(
+            getOperation(), transformModule,
+            options.enableExpensiveChecks(true), entryPoint)))
+      return signalPassFailure();
+  }
+
+private:
+  /// Transform interpreter options.
+  transform::TransformOptions options;
+
+  /// The separate transform module to be used for transformations, shared
+  /// across multiple instances of the pass if it is applied in parallel to
+  /// avoid potentially expensive cloning. MUST NOT be modified after the pass
+  /// has been initialized.
+  std::shared_ptr<OwningOpRef<ModuleOp>> sharedTransformModule = nullptr;
+};
+} // namespace
diff --git a/mlir/lib/Dialect/Transform/Transforms/PreloadLibraryPass.cpp b/mlir/lib/Dialect/Transform/Transforms/PreloadLibraryPass.cpp
new file mode 100644
index 000000000000000..795e31637617fe4
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/Transforms/PreloadLibraryPass.cpp
@@ -0,0 +1,47 @@
+//===- PreloadLibraryPass.cpp - Pass to preload a transform library -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Transforms/Passes.h"
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
+
+using namespace mlir;
+
+namespace mlir {
+namespace transform {
+#define GEN_PASS_DEF_PRELOADLIBRARYPASS
+#include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
+} // namespace transform
+} // namespace mlir
+
+namespace {
+/// Pass that, when initialized, assembles the transform library found at
+/// `transformLibraryPaths` and registers it with the context's Transform
+/// dialect so that subsequent transform interpreter passes can use it.
+class PreloadLibraryPass
+    : public transform::impl::PreloadLibraryPassBase<PreloadLibraryPass> {
+public:
+  using Base::Base;
+
+  /// Parses and merges the library files, then hands the resulting module
+  /// over to the Transform dialect. Fails if the library cannot be assembled.
+  LogicalResult initialize(MLIRContext *context) override {
+    OwningOpRef<ModuleOp> mergedParsedLibraries;
+    if (failed(transform::detail::assembleTransformLibraryFromPaths(
+            context, transformLibraryPaths, mergedParsedLibraries)))
+      return failure();
+    // TODO: use a resource blob.
+    auto *dialect = context->getOrLoadDialect<transform::TransformDialect>();
+    dialect->registerLibraryModule(std::move(mergedParsedLibraries));
+    return success();
+  }
+
+  // All the work happens in `initialize`; the payload IR is not modified.
+  void runOnOperation() override {}
+};
+} // namespace
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index 5f35b6789dc94fe..538c81fe39fddb2 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -357,59 +357,6 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
                          extraMappings, options);
 }
 
-/// Expands the given list of `paths` to a list of `.mlir` files.
-///
-/// Each entry in `paths` may either be a regular file, in which case it ends up
-/// in the result list, or a directory, in which case all (regular) `.mlir`
-/// files in that directory are added. Any other file types lead to a failure.
-static LogicalResult
-expandPathsToMLIRFiles(ArrayRef<std::string> &paths, MLIRContext *const context,
-                       SmallVectorImpl<std::string> &fileNames) {
-  for (const std::string &path : paths) {
-    auto loc = FileLineColLoc::get(context, path, 0, 0);
-
-    if (llvm::sys::fs::is_regular_file(path)) {
-      LLVM_DEBUG(DBGS() << "Adding '" << path << "' to list of files\n");
-      fileNames.push_back(path);
-      continue;
-    }
-
-    if (!llvm::sys::fs::is_directory(path)) {
-      return emitError(loc)
-             << "'" << path << "' is neither a file nor a directory";
-    }
-
-    LLVM_DEBUG(DBGS() << "Looking for files in '" << path << "':\n");
-
-    std::error_code ec;
-    for (llvm::sys::fs::directory_iterator it(path, ec), itEnd;
-         it != itEnd && !ec; it.increment(ec)) {
-      const std::string &fileName = it->path();
-
-      if (it->type() != llvm::sys::fs::file_type::regular_file) {
-        LLVM_DEBUG(DBGS() << "  Skipping non-regular file '" << fileName
-                          << "'\n");
-        continue;
-      }
-
-      if (!StringRef(fileName).endswith(".mlir")) {
-        LLVM_DEBUG(DBGS() << "  Skipping '" << fileName
-                          << "' because it does not end with '.mlir'\n");
-        continue;
-      }
-
-      LLVM_DEBUG(DBGS() << "  Adding '" << fileName << "' to list of files\n");
-      fileNames.push_back(fileName);
-    }
-
-    if (ec)
-      return emitError(loc) << "error while opening files in '" << path
-                            << "': " << ec.message();
-  }
-
-  return success();
-}
-
 LogicalResult transform::detail::interpreterBaseInitializeImpl(
     MLIRContext *context, StringRef transformFileName,
     ArrayRef<std::string> transformLibraryPaths,
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
index 1a6ebdd16232e8a..eb6f83dc8330563 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
@@ -23,6 +23,7 @@
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/FileSystem.h"
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/raw_ostream.h"
 
@@ -31,6 +32,68 @@ using namespace mlir;
 #define DEBUG_TYPE "transform-dialect-interpreter-utils"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
 
+/// Expands the given list of `paths` to a list of `.mlir` files.
+///
+/// Each entry in `paths` may either be a regular file, in which case it ends up
+/// in the result list, or a directory, in which case all (regular) `.mlir`
+/// files in that directory are added. Any other file types lead to a failure.
+/// Files found in a directory are appended in sorted order so that the result
+/// (and hence the order in which libraries are merged) does not depend on the
+/// filesystem's directory iteration order.
+LogicalResult transform::detail::expandPathsToMLIRFiles(
+    ArrayRef<std::string> &paths, MLIRContext *context,
+    SmallVectorImpl<std::string> &fileNames) {
+  for (const std::string &path : paths) {
+    auto loc = FileLineColLoc::get(context, path, 0, 0);
+
+    if (llvm::sys::fs::is_regular_file(path)) {
+      LLVM_DEBUG(DBGS() << "Adding '" << path << "' to list of files\n");
+      fileNames.push_back(path);
+      continue;
+    }
+
+    if (!llvm::sys::fs::is_directory(path)) {
+      return emitError(loc)
+             << "'" << path << "' is neither a file nor a directory";
+    }
+
+    LLVM_DEBUG(DBGS() << "Looking for files in '" << path << "':\n");
+
+    // Entries of this directory start at this index; see the sort below.
+    const size_t firstDirEntry = fileNames.size();
+    std::error_code ec;
+    for (llvm::sys::fs::directory_iterator it(path, ec), itEnd;
+         it != itEnd && !ec; it.increment(ec)) {
+      const std::string &fileName = it->path();
+
+      if (it->type() != llvm::sys::fs::file_type::regular_file) {
+        LLVM_DEBUG(DBGS() << "  Skipping non-regular file '" << fileName
+                          << "'\n");
+        continue;
+      }
+
+      if (!StringRef(fileName).endswith(".mlir")) {
+        LLVM_DEBUG(DBGS() << "  Skipping '" << fileName
+                          << "' because it does not end with '.mlir'\n");
+        continue;
+      }
+
+      LLVM_DEBUG(DBGS() << "  Adding '" << fileName << "' to list of files\n");
+      fileNames.push_back(fileName);
+    }
+
+    if (ec)
+      return emitError(loc) << "error while opening files in '" << path
+                            << "': " << ec.message();
+
+    // Directory iteration order is filesystem-dependent: sort this
+    // directory's entries to keep the resulting file list deterministic.
+    llvm::sort(fileNames.begin() + firstDirEntry, fileNames.end());
+  }
+
+  return success();
+}
+
 LogicalResult transform::detail::parseTransformModuleFromFile(
     MLIRContext *context, llvm::StringRef transformFileName,
     OwningOpRef<ModuleOp> &transformModule) {
@@ -91,10 +145,52 @@ transform::detail::findTransformEntryPoint(Operation *root, ModuleOp module,
   return nullptr;
 }
 
-/// Return whether `func1` can be merged into `func2`. For that to work `func1`
-/// has to be a declaration (aka has to be external) and `func2` either has to
-/// be a declaration as well, or it has to be public (otherwise, it wouldn't
-/// be visible by `func1`).
+/// Parses every `.mlir` file found in `transformLibraryPaths` and merges their
+/// symbols into a fresh `__transform` module carrying the
+/// `transform.with_named_sequence` attribute. On success the result is moved
+/// into `transformModule`; on failure `transformModule` is left untouched.
+LogicalResult transform::detail::assembleTransformLibraryFromPaths(
+    MLIRContext *context, ArrayRef<std::string> transformLibraryPaths,
+    OwningOpRef<ModuleOp> &transformModule) {
+  // Assemble list of library files.
+  SmallVector<std::string> libraryFileNames;
+  if (failed(detail::expandPathsToMLIRFiles(transformLibraryPaths, context,
+                                            libraryFileNames)))
+    return failure();
+
+  // Parse modules from library files.
+  SmallVector<OwningOpRef<ModuleOp>> parsedLibraries;
+  for (const std::string &libraryFileName : libraryFileNames) {
+    OwningOpRef<ModuleOp> parsedLibrary;
+    auto loc = FileLineColLoc::get(context, libraryFileName, 0, 0);
+    if (failed(transform::detail::parseTransformModuleFromFile(
+            context, libraryFileName, parsedLibrary)))
+      return emitError(loc) << "failed to parse transform library module";
+    parsedLibraries.push_back(std::move(parsedLibrary));
+  }
+
+  // Merge parsed libraries into one module.
+  auto loc = FileLineColLoc::get(context, "<shared-library-module>", 0, 0);
+  OwningOpRef<ModuleOp> mergedParsedLibraries =
+      ModuleOp::create(loc, "__transform");
+  mergedParsedLibraries.get()->setAttr("transform.with_named_sequence",
+                                       UnitAttr::get(context));
+  // TODO: extend `mergeSymbolsInto` to support multiple `other` modules.
+  for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) {
+    if (failed(transform::detail::mergeSymbolsInto(
+            mergedParsedLibraries.get(), std::move(parsedLibrary))))
+      return mergedParsedLibraries->emitError()
+             << "failed to verify merged transform module";
+  }
+
+  transformModule = std::move(mergedParsedLibraries);
+  return success();
+}
+
+/// Return whether `func1` can be merged into `func2`. For that to work
+/// `func1` has to be a declaration (aka has to be external) and `func2`
+/// either has to be a declaration as well, or it has to be public (otherwise,
+/// it wouldn't be visible by `func1`).
 static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
   return func1.isExternal() && (func2.isPublic() || func2.isExternal());
 }
@@ -281,8 +376,9 @@ transform::detail::mergeSymbolsInto(Operation *target,
       auto collidingFuncOp =
           cast<FunctionOpInterface>(collidingOp.getOperation());
 
-      // Both ops are in the target module now and can be treated symmetrically,
-      // so w.l.o.g. we can reduce to merging `funcOp` into `collidingFuncOp`.
+      // Both ops are in the target module now and can be treated
+      // symmetrically, so w.l.o.g. we can reduce to merging `funcOp` into
+      // `collidingFuncOp`.
       if (!canMergeInto(funcOp, collidingFuncOp)) {
         std::swap(funcOp, collidingFuncOp);
       }
@@ -317,18 +413,25 @@ LogicalResult transform::applyTransformNamedSequence(
     const TransformOptions &options, StringRef entryPoint) {
   Operation *transformRoot =
       detail::findTransformEntryPoint(payload, transformModule, entryPoint);
-  if (!transformRoot)
-    return failure();
+  if (!transformRoot) {
+    return payload->emitError()
+           << "could not find transform entry point: " << entryPoint
+           << " in either payload or transform module";
+  }
 
   // `transformModule` may not be modified.
-  OwningOpRef<Operation *> clonedTransformModule(transformModule->clone());
   if (transformModule && !transformModule->isAncestor(transformRoot)) {
+    OwningOpRef<Operation *> clonedTransformModule(transformModule->clone());
     if (failed(detail::mergeSymbolsInto(
             SymbolTable::getNearestSymbolTable(transformRoot),
-            std::move(clonedTransformModule))))
-      return failure();
+            std::move(clonedTransformModule)))) {
+      return payload->emitError() << "failed to merge symbols";
+    }
   }
 
+  LLVM_DEBUG(DBGS() << "Apply\n" << *transformRoot << "\n");
+  LLVM_DEBUG(DBGS() << "To\n" << *payload << "\n");
+
   // Apply the transform to the IR, do not enforce top-level constraints.
   RaggedArray<MappedValue> noExtraMappings;
   return applyTransforms(payload, cast<TransformOpInterface>(transformRoot),
diff --git a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir
index c25212fbe98782b..d7472e685a7d0a0 100644
--- a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir
@@ -2,12 +2,14 @@
 
 // RUN: mlir-opt %s -test-lower-to-llvm -cse | FileCheck %s
 
-// RUN: mlir-opt %s -test-transform-dialect-interpreter="transform-library-paths=%p/lower-to-llvm-transform-symbol-def.mlir debug-payload-root-tag=payload" \
-// RUN:   -test-transform-dialect-erase-schedule -cse \
+// RUN: mlir-opt %s \
+// RUN:   -transform-preload-library="transform-library-paths=%p/../Transform/Library/lower-to-llvm.mlir" \
+// RUN:   -transform-interpreter="entry-point=some_other_entry_point" -cse \
 // RUN: | FileCheck %s
+
+// The external named sequence declared below is defined in the preloaded
+// library; the interpreter starts at @some_other_entry_point.
 
-module attributes {transform.target_tag="payload"} {
-
 // Check that we properly lower to llvm memref operations that require to be
 // expanded first, like `memref.subview`.
 func.func @subview(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>, %arg0 : index, %arg1 : index, %arg2 : index)
@@ -51,14 +50,14 @@ func.func @subview(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>, %arg0 : in
   return %1 : memref<?x?xf32, strided<[?, ?], offset: ?>>
 }
 
-} // transform payload
-
 module @named_inclusion_in_named attributes { transform.with_named_sequence } {
-  transform.named_sequence private @lower_to_cpu(!transform.any_op {transform.consumed}) -> !transform.any_op
+  transform.named_sequence @lower_to_cpu(
+    !transform.any_op {transform.consumed}) -> !transform.any_op
 
-  transform.sequence failures(propagate) {
-  ^bb1(%toplevel_module: !transform.any_op):
-    %m2 = transform.include @lower_to_cpu failures(suppress) (%toplevel_module) 
+  transform.named_sequence @some_other_entry_point(
+      %toplevel_module: !transform.any_op {transform.consumed}) {
+    transform.include @lower_to_cpu failures(suppress) (%toplevel_module)
       : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
   }
 }
diff --git a/mlir/test/Dialect/LLVM/lower-to-llvm-transform-symbol-def.mlir b/mlir/test/Dialect/Transform/Library/lower-to-llvm.mlir
similarity index 100%
rename from mlir/test/Dialect/LLVM/lower-to-llvm-transform-symbol-def.mlir
rename to mlir/test/Dialect/Transform/Library/lower-to-llvm.mlir
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/Dialect/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/Dialect/BUILD.bazel
index 061404078bc32ca..421102abd5193df 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/Dialect/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/Dialect/BUILD.bazel
@@ -17,10 +17,10 @@ package(default_visibility = ["//visibility:public"])
             "//mlir/test:lit_data",
         ] + glob([
             "IRDL/*.irdl.mlir",
-            "LLVM/*-symbol-def.mlir",
             "Transform/*-source.mlir",
             "Transform/*-symbol-def.mlir",
             "Transform/*-symbol-decl-and-schedule.mlir",
+            "Transform/Library/*.mlir",
             "Transform/test-interpreter-library/*.mlir",
         ]),
     )

>From 1c42a6a0f98d6600ced88625ccca790ccefbf9b6 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolas.vasilache at gmail.com>
Date: Mon, 9 Oct 2023 08:40:04 +0000
Subject: [PATCH 2/5] Use apply_registered_pass on a func to avoid the
 transform self-application issue

---
 mlir/lib/Dialect/Transform/IR/TransformOps.cpp        | 11 +++++------
 mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir         |  4 ++--
 .../test/Dialect/Transform/Library/lower-to-llvm.mlir |  6 +++---
 3 files changed, 10 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 660879c29728022..dcbaaf04c49490d 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1761,10 +1761,11 @@ DiagnosedSilenceableFailure
 transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter,
                                   transform::TransformResults &results,
                                   transform::TransformState &state) {
-  if (getBody().empty()) {
-    emitOpError("Cannot apply a bodyless named sequence, did you forget to link?");
-    return DiagnosedSilenceableFailure::definiteFailure();
-  }
+  // A declaration without a body cannot be applied: it must have been linked
+  // with its definition (e.g. from a preloaded library) beforehand.
+  if (isExternal())
+    return emitDefiniteFailure() << "unresolved external named sequence";
+
   // Map the entry block argument to the list of operations.
   // Note: this is the same implementation as PossibleTopLevelTransformOp but
   // without attaching the interface / trait since that is tailored to a
@@ -1774,8 +1775,8 @@ transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter,
       state, this->getOperation(), getBody())))
     return DiagnosedSilenceableFailure::definiteFailure();
 
-  return applySequenceBlock(getBody().front(), FailurePropagationMode::Propagate, state,
-                            results);
+  return applySequenceBlock(getBody().front(),
+                            FailurePropagationMode::Propagate, state, results);
 }
 
 void transform::NamedSequenceOp::getEffects(
diff --git a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir
index d7472e685a7d0a0..d3048e071ceb2c5 100644
--- a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir
@@ -52,10 +52,10 @@ func.func @subview(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>, %arg0 : in
 
 module @named_inclusion_in_named attributes { transform.with_named_sequence } {
   transform.named_sequence @lower_to_cpu(
-    !transform.any_op {transform.consumed}) -> !transform.any_op
+    !transform.any_op {transform.readonly}) -> !transform.any_op
 
   transform.named_sequence @some_other_entry_point(
-      %toplevel_module: !transform.any_op {transform.consumed}) {
+      %toplevel_module: !transform.any_op {transform.readonly}) {
     transform.include @lower_to_cpu failures(suppress) (%toplevel_module)
       : (!transform.any_op) -> (!transform.any_op)
     transform.yield
diff --git a/mlir/test/Dialect/Transform/Library/lower-to-llvm.mlir b/mlir/test/Dialect/Transform/Library/lower-to-llvm.mlir
index 46fba4a51792784..b4b93237a81b632 100644
--- a/mlir/test/Dialect/Transform/Library/lower-to-llvm.mlir
+++ b/mlir/test/Dialect/Transform/Library/lower-to-llvm.mlir
@@ -4,7 +4,7 @@
 module @lower_module_to_cpu attributes { transform.with_named_sequence } {
 
 transform.named_sequence @lower_to_cpu(
-    %module: !transform.any_op {transform.consumed}) -> !transform.any_op {
+    %module: !transform.any_op {transform.readonly}) -> !transform.any_op {
 
   %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.any_op
   %f = transform.apply_registered_pass "convert-vector-to-scf" to %func : (!transform.any_op) -> !transform.any_op
@@ -32,8 +32,8 @@ transform.named_sequence @lower_to_cpu(
     partial_conversion
   } : !transform.any_op
 
-  %m2 = transform.apply_registered_pass "reconcile-unrealized-casts" to %module : (!transform.any_op) -> !transform.any_op
-  transform.yield %m2 : !transform.any_op
+  %f6 = transform.apply_registered_pass "reconcile-unrealized-casts" to %f5 : (!transform.any_op) -> !transform.any_op
+  transform.yield %module : !transform.any_op
 }
 
 } // transform module

>From d22addaf289bd87bbaaa5eda2f59c79da3a9faaa Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolas.vasilache at gmail.com>
Date: Mon, 9 Oct 2023 20:16:18 +0200
Subject: [PATCH 3/5] Update test and lower_to_llvm to work as expected when
 called from a toplevel named sequence

---
 mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp | 10 +++++++---
 .../test/Dialect/Transform/Library/lower-to-llvm.mlir | 11 ++++++++++-
 2 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 4f88b8522e54c80..91d8302150808d0 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -60,9 +60,16 @@ ArrayRef<Operation *>
 transform::TransformState::getPayloadOpsView(Value value) const {
   const TransformOpMapping &operationMapping = getMapping(value).direct;
   auto iter = operationMapping.find(value);
-  assert(
-      iter != operationMapping.end() &&
-      "cannot find mapping for payload handle (param/value handle provided?)");
+
+#ifndef NDEBUG
+  if (iter == operationMapping.end()) {
+    // Print the offending handle before aborting to ease debugging; this is
+    // compiled out in release builds, like the assertion it replaces.
+    value.dump();
+    llvm_unreachable("cannot find mapping for payload handle (param/value "
+                     "handle provided?)");
+  }
+#endif // NDEBUG
   return iter->getSecond();
 }
 
diff --git a/mlir/test/Dialect/Transform/Library/lower-to-llvm.mlir b/mlir/test/Dialect/Transform/Library/lower-to-llvm.mlir
index b4b93237a81b632..5b4a410b18561ef 100644
--- a/mlir/test/Dialect/Transform/Library/lower-to-llvm.mlir
+++ b/mlir/test/Dialect/Transform/Library/lower-to-llvm.mlir
@@ -32,7 +32,16 @@ transform.named_sequence @lower_to_cpu(
     partial_conversion
   } : !transform.any_op
 
-  %f6 = transform.apply_registered_pass "reconcile-unrealized-casts" to %f5 : (!transform.any_op) -> !transform.any_op
+  // Need to rematch here because:
+  //   1. applying reconcile-unrealized-casts to the whole module would, at
+  //      this time, also apply it to the transform IR itself when called
+  //      from a named sequence.
+  //   2. apply_conversion_patterns consumes the func handle but does not
+  //      produce a new handle to the resulting llvm.func.
+  %f6 = transform.structured.match ops{["llvm.func"]} in %module 
+    : (!transform.any_op) -> !transform.any_op
+  %f7 = transform.apply_registered_pass "reconcile-unrealized-casts" to %f6
+    : (!transform.any_op) -> !transform.any_op
   transform.yield %module : !transform.any_op
 }
 

>From 7e284e18f9dd3c2d34b33f0433a8c750e46ebb8f Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolas.vasilache at gmail.com>
Date: Mon, 9 Oct 2023 20:17:20 +0200
Subject: [PATCH 4/5] Update test and lower_to_llvm to work as expected when
 called from a toplevel named sequence

---
 mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir          | 4 ++--
 mlir/test/Dialect/Transform/Library/lower-to-llvm.mlir | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir
index d3048e071ceb2c5..a39704e629cb561 100644
--- a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir
@@ -51,12 +51,12 @@ func.func @subview(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>, %arg0 : in
 }
 
 module @named_inclusion_in_named attributes { transform.with_named_sequence } {
-  transform.named_sequence @lower_to_cpu(
+  transform.named_sequence @lower_to_llvm(
     !transform.any_op {transform.readonly}) -> !transform.any_op
 
   transform.named_sequence @some_other_entry_point(
       %toplevel_module: !transform.any_op {transform.readonly}) {
-    transform.include @lower_to_cpu failures(suppress) (%toplevel_module)
+    transform.include @lower_to_llvm failures(suppress) (%toplevel_module)
       : (!transform.any_op) -> (!transform.any_op)
     transform.yield
   }
diff --git a/mlir/test/Dialect/Transform/Library/lower-to-llvm.mlir b/mlir/test/Dialect/Transform/Library/lower-to-llvm.mlir
index 5b4a410b18561ef..0ba50bd2362b34d 100644
--- a/mlir/test/Dialect/Transform/Library/lower-to-llvm.mlir
+++ b/mlir/test/Dialect/Transform/Library/lower-to-llvm.mlir
@@ -1,9 +1,9 @@
 // RUN: mlir-opt %s
 
 /// Schedule to lower to LLVM.
-module @lower_module_to_cpu attributes { transform.with_named_sequence } {
+module @lower_module_to_llvm attributes { transform.with_named_sequence } {
 
-transform.named_sequence @lower_to_cpu(
+transform.named_sequence @lower_to_llvm(
     %module: !transform.any_op {transform.readonly}) -> !transform.any_op {
 
   %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.any_op

>From cf603600fa903110a2ef69d47f02f48ca40a51d7 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolas.vasilache at gmail.com>
Date: Tue, 10 Oct 2023 18:34:59 +0000
Subject: [PATCH 5/5] Format

---
 mlir/lib/Dialect/Transform/IR/TransformOps.cpp            | 2 +-
 mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index dcbaaf04c49490d..61f1195b21275aa 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1770,7 +1770,7 @@ transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter,
   // dangling top-level op that does not get "called".
   auto scope = state.make_region_scope(getBody());
   if (failed(detail::mapPossibleTopLevelTransformOpBlockArguments(
-      state, this->getOperation(), getBody())))
+          state, this->getOperation(), getBody())))
     return DiagnosedSilenceableFailure::definiteFailure();
 
   return applySequenceBlock(getBody().front(),
diff --git a/mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp b/mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp
index 4a6921b8611bf46..7d25e1f8c6eb03e 100644
--- a/mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp
@@ -38,7 +38,8 @@ class InterpreterPass
 
   void runOnOperation() override {
     if (failed(transform::applyTransformNamedSequence(
-            getOperation(), sharedTransformModule->get(), options.enableExpensiveChecks(true), entryPoint)))
+            getOperation(), sharedTransformModule->get(),
+            options.enableExpensiveChecks(true), entryPoint)))
       return signalPassFailure();
   }
 



More information about the llvm-commits mailing list