[Mlir-commits] [mlir] [mlir] add an example of using transform dialect standalone (PR #82623)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 22 06:15:28 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Oleksandr "Alex" Zinenko (ftynse)

<details>
<summary>Changes</summary>

The Transform dialect interpreter is designed to be usable outside of the pass pipeline, as the main program transformation driver, e.g., for languages with explicit schedules. This patch provides an example of such usage with a couple of tests.

---

Patch is 24.43 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/82623.diff


12 Files Affected:

- (modified) mlir/examples/CMakeLists.txt (+1) 
- (added) mlir/examples/transform-opt/CMakeLists.txt (+26) 
- (added) mlir/examples/transform-opt/README.md (+40) 
- (added) mlir/examples/transform-opt/mlir-transform-opt.cpp (+342) 
- (modified) mlir/test/CMakeLists.txt (+1) 
- (added) mlir/test/Examples/transform-opt/empty.mlir (+12) 
- (added) mlir/test/Examples/transform-opt/external-decl.mlir (+18) 
- (added) mlir/test/Examples/transform-opt/external-def.mlir (+8) 
- (added) mlir/test/Examples/transform-opt/pass.mlir (+19) 
- (added) mlir/test/Examples/transform-opt/self-contained.mlir (+21) 
- (added) mlir/test/Examples/transform-opt/syntax-error.mlir (+5) 
- (modified) mlir/test/lit.cfg.py (+1) 


``````````diff
diff --git a/mlir/examples/CMakeLists.txt b/mlir/examples/CMakeLists.txt
index d256bf1a5cbb13..2a1cac34d8c290 100644
--- a/mlir/examples/CMakeLists.txt
+++ b/mlir/examples/CMakeLists.txt
@@ -1,3 +1,4 @@
 add_subdirectory(toy)
 add_subdirectory(transform)
+add_subdirectory(transform-opt)
 add_subdirectory(minimal-opt)
diff --git a/mlir/examples/transform-opt/CMakeLists.txt b/mlir/examples/transform-opt/CMakeLists.txt
new file mode 100644
index 00000000000000..8e23555d0b5d73
--- /dev/null
+++ b/mlir/examples/transform-opt/CMakeLists.txt
@@ -0,0 +1,26 @@
+# Gather the libraries of all upstream dialects, conversions and extensions:
+# the tool registers every upstream dialect, extension and pass at startup.
+get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
+get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
+get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
+
+# Everything the standalone Transform dialect interpreter needs. The same list
+# is passed to the DEPENDS of `add_mlir_tool` and to `target_link_libraries`.
+set(transform_opt_libs
+  MLIRAnalysis
+  MLIRIR
+  MLIRParser
+  MLIRSupport
+  MLIRTransformDialect
+  MLIRTransformDialectTransforms
+  MLIRTransforms
+  ${dialect_libs}
+  ${conversion_libs}
+  ${extension_libs}
+)
+
+add_mlir_tool(mlir-transform-opt
+  mlir-transform-opt.cpp
+
+  DEPENDS
+  ${transform_opt_libs}
+)
+target_link_libraries(mlir-transform-opt PRIVATE ${transform_opt_libs})
+llvm_update_compile_flags(mlir-transform-opt)
+mlir_check_all_link_libraries(mlir-transform-opt)
diff --git a/mlir/examples/transform-opt/README.md b/mlir/examples/transform-opt/README.md
new file mode 100644
index 00000000000000..e9c8cc0173c7b2
--- /dev/null
+++ b/mlir/examples/transform-opt/README.md
@@ -0,0 +1,40 @@
+# Standalone Transform Dialect Interpreter
+
+This is an example of using the Transform dialect interpreter functionality standalone, that is, outside of the regular pass pipeline. The example is a
+binary capable of processing MLIR source files similar to `mlir-opt` and other
+optimizer drivers, with the entire transformation process driven by a Transform
+dialect script. This script can be embedded into the source file or provided in
+a separate MLIR source file.
+
+Either the input module or the transform module must contain a top-level symbol
+named `__transform_main`, which is used as the entry point to the transformation
+script.
+
+```sh
+mlir-transform-opt payload_with_embedded_transform.mlir
+mlir-transform-opt payload.mlir -transform=transform.mlir
+```
+
+The name of the entry point can be overridden using command-line options.
+
+```sh
+mlir-transform-opt payload.mlir -transform-entry-point=another_entry_point
+```
+
+Transform scripts can reference symbols defined in other source files, called
+libraries, which can be supplied to the binary through command-line options.
+Libraries will be embedded into the main transformation module by the tool and
+the interpreter will process everything as a single module. A debug option is
+available to see the contents of the transform module before it goes into the interpreter.
+
+```sh
+mlir-transform-opt payload.mlir -transform=transform.mlir \
+  -transform-library=external_definitions_1.mlir \
+  -transform-library=external_definitions_2.mlir \
+  -dump-library-module
+```
+
+Check out the [Transform dialect
+tutorial](https://mlir.llvm.org/docs/Tutorials/transform/) as well as
+[documentation](https://mlir.llvm.org/docs/Dialects/Transform/) to learn more
+about the dialect. 
diff --git a/mlir/examples/transform-opt/mlir-transform-opt.cpp b/mlir/examples/transform-opt/mlir-transform-opt.cpp
new file mode 100644
index 00000000000000..3aa4e2485aea25
--- /dev/null
+++ b/mlir/examples/transform-opt/mlir-transform-opt.cpp
@@ -0,0 +1,342 @@
+//===- mlir-transform-opt.cpp -----------------------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/Utils.h"
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/InitAllDialects.h"
+#include "mlir/InitAllExtensions.h"
+#include "mlir/InitAllPasses.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Tools/mlir-opt/MlirOptMain.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/ToolOutputFile.h"
+#include <cstdlib>
+
+namespace {
+
+using namespace llvm;
+
+/// Structure containing command line options for the tool. Each option
+/// registers itself with the LLVM command-line parser when an instance of
+/// this structure is created.
+struct MlirTransformOptCLOptions {
+  /// Accept operations from dialects that are not registered in the context.
+  cl::opt<bool> allowUnregisteredDialects{
+      "allow-unregistered-dialect",
+      cl::desc("Allow operations coming from an unregistered dialect"),
+      cl::init(false)};
+
+  /// Match emitted diagnostics against `expected-*` annotations in the
+  /// parsed files.
+  cl::opt<bool> verifyDiagnostics{
+      "verify-diagnostics",
+      cl::desc("Check that emitted diagnostics match expected-* lines "
+               "on the corresponding line"),
+      cl::init(false)};
+
+  /// File containing the payload IR ("-" denotes standard input).
+  cl::opt<std::string> payloadFilename{cl::Positional, cl::desc("<input file>"),
+                                       cl::init("-")};
+
+  /// File receiving the transformed payload IR ("-" denotes standard output).
+  cl::opt<std::string> outputFilename{"o", cl::desc("Output filename"),
+                                      cl::value_desc("filename"),
+                                      cl::init("-")};
+
+  /// File containing the transform script entry point. When empty, the entry
+  /// point is looked up in the payload file itself.
+  cl::opt<std::string> transformMainFilename{
+      "transform",
+      cl::desc("File containing entry point of the transform script, if "
+               "different from the input file"),
+      cl::value_desc("filename"), cl::init("")};
+
+  /// Files defining external transform symbols; they are merged into the
+  /// transform module before interpretation. Uses the same value description
+  /// as the other file-valued options for a consistent `--help` output.
+  cl::list<std::string> transformLibraryFilenames{
+      "transform-library",
+      cl::desc("File(s) containing definitions of "
+               "additional transform script symbols"),
+      cl::value_desc("filename")};
+
+  /// Name of the symbol used as the transform script entry point.
+  cl::opt<std::string> transformEntryPoint{
+      "transform-entry-point",
+      cl::desc("Name of the entry point transform symbol"),
+      cl::init(mlir::transform::TransformDialect::kTransformEntryPointSymbolName
+                   .str())};
+
+  /// Turn off the interpreter's expensive checks for faster execution.
+  cl::opt<bool> disableExpensiveChecks{
+      "disable-expensive-checks",
+      cl::desc("Disables potentially expensive checks in the transform "
+               "interpreter, providing more speed at the expense of "
+               "potential memory problems and silent corruptions"),
+      cl::init(false)};
+
+  /// Dump the combined transform module (libraries merged in) before the
+  /// transformations are applied.
+  cl::opt<bool> dumpLibraryModule{
+      "dump-library-module",
+      cl::desc("Prints the combined library module before the output"),
+      cl::init(false)};
+};
+} // namespace
+
+/// Process-wide, lazily constructed storage for the command-line options.
+/// Wrapping the options in `llvm::ManagedStatic` keeps them out of static
+/// initialization: they only come into existence when `registerCLOptions`
+/// runs, and are torn down together with other managed statics at shutdown.
+/// A standalone tool that is never linked as a library (where the options
+/// would pollute the global option list) does not strictly need this, but it
+/// is good practice.
+static llvm::ManagedStatic<MlirTransformOptCLOptions> clOptions;
+
+/// Registers the tool's command-line options by forcing construction of the
+/// managed options structure.
+static void registerCLOptions() {
+  // Dereferencing the ManagedStatic constructs the options on first access.
+  (void)*clOptions;
+}
+
+namespace {
+/// MLIR has deeply rooted expectations that the LLVM source manager contains
+/// exactly one buffer, until at least the lexer level. This class wraps
+/// multiple LLVM source managers each managing a buffer to match MLIR's
+/// expectations while still providing a centralized handling mechanism.
+class TransformSourceMgr {
+public:
+  /// Constructs the source manager indicating whether diagnostic messages will
+  /// be verified later on.
+  explicit TransformSourceMgr(bool verifyDiagnostics)
+      : verifyDiagnostics(verifyDiagnostics) {}
+
+  /// Destroys the source manager. Note that `checkResult` must have been
+  /// called on this instance before destroying it.
+  ~TransformSourceMgr() {
+    assert(resultChecked && "must check the result of diagnostic handlers by "
+                            "running TransformSourceMgr::checkResult");
+  }
+
+  /// Parses the given buffer and creates the top-level operation of the kind
+  /// specified as template argument in the given context. Additional parsing
+  /// options may be provided. Each buffer gets its own LLVM source manager and
+  /// its own diagnostic handler (a verifier if verification was requested).
+  template <typename OpTy = mlir::Operation *>
+  mlir::OwningOpRef<OpTy> parseBuffer(std::unique_ptr<MemoryBuffer> buffer,
+                                      mlir::MLIRContext &context,
+                                      const mlir::ParserConfig &config) {
+    // Create a single-buffer LLVM source manager. Note that `unique_ptr` allows
+    // the code below to capture a reference to the source manager in such a way
+    // that it is not invalidated when the vector contents is eventually
+    // reallocated.
+    llvm::SourceMgr &mgr =
+        *sourceMgrs.emplace_back(std::make_unique<llvm::SourceMgr>());
+    mgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
+
+    // Choose the type of diagnostic handler depending on whether diagnostic
+    // verification needs to happen and store it.
+    if (verifyDiagnostics) {
+      diagHandlers.push_back(
+          std::make_unique<mlir::SourceMgrDiagnosticVerifierHandler>(mgr,
+                                                                     &context));
+    } else {
+      diagHandlers.push_back(
+          std::make_unique<mlir::SourceMgrDiagnosticHandler>(mgr, &context));
+    }
+
+    // Defer to MLIR's parser.
+    return mlir::parseSourceFile<OpTy>(mgr, config);
+  }
+
+  /// If diagnostic message verification has been requested upon construction of
+  /// this source manager, performs the verification for every parsed buffer,
+  /// reports errors and returns failure if any of them failed. Otherwise passes
+  /// through the given value.
+  mlir::LogicalResult checkResult(mlir::LogicalResult result) {
+    resultChecked = true;
+    if (!verifyDiagnostics)
+      return result;
+
+    // Deliberately do not short-circuit: each `verify` call reports the
+    // unmatched expectations of its own buffer, so every handler must run for
+    // the user to see all mismatches at once.
+    bool anyFailed = false;
+    for (const auto &handler : diagHandlers) {
+      // All handlers are verifiers when `verifyDiagnostics` is set, see
+      // `parseBuffer`, so the downcast is safe.
+      auto *verifier = static_cast<mlir::SourceMgrDiagnosticVerifierHandler *>(
+          handler.get());
+      if (mlir::failed(verifier->verify()))
+        anyFailed = true;
+    }
+    return mlir::failure(anyFailed);
+  }
+
+private:
+  /// Indicates whether diagnostic message verification is requested.
+  const bool verifyDiagnostics;
+
+  /// Indicates that diagnostic message verification has taken place, and the
+  /// destruction is therefore safe.
+  bool resultChecked = false;
+
+  /// Storage for per-buffer source managers and diagnostic handlers. These are
+  /// wrapped into unique pointers in order to make it safe to capture
+  /// references to these objects: if the vector is reallocated, the unique
+  /// pointer objects are moved but the pointee addresses won't change. Also,
+  /// for handlers, this allows to store the pointer to the base class.
+  /// The handlers reference the source managers, so `diagHandlers` must stay
+  /// declared after `sourceMgrs` to be destroyed first.
+  SmallVector<std::unique_ptr<llvm::SourceMgr>> sourceMgrs;
+  SmallVector<std::unique_ptr<mlir::SourceMgrDiagnosticHandler>> diagHandlers;
+};
+} // namespace
+
+/// Convenience overload of `mlir::transform::applyTransforms` that passes an
+/// empty extra mapping and does not require the entry point transform op to
+/// be a top-level operation.
+static mlir::LogicalResult
+applyTransforms(mlir::Operation *payloadRoot,
+                mlir::transform::TransformOpInterface transformRoot,
+                const mlir::transform::TransformOptions &options) {
+  constexpr bool kEnforceToplevelTransformOp = false;
+  return mlir::transform::applyTransforms(payloadRoot, transformRoot,
+                                          /*extraMapping=*/{}, options,
+                                          kEnforceToplevelTransformOp);
+}
+
+/// Applies transforms indicated in the transform dialect script to the input
+/// buffer. The transform script may be embedded in the input buffer or
+/// provided as a separate buffer. The transform script may have external
+/// symbols, the definitions of which must be provided in transform library
+/// buffers. If the application is successful, prints the transformed input
+/// buffer into the given output stream. Additional configuration options are
+/// derived from command-line options.
+///
+/// When no separate transform buffer is provided, the top-level operation of
+/// the input buffer must be a module because a copy of it serves as the
+/// transform module; any other top-level operation is reported as an error.
+///
+/// Returns failure if parsing, library merging, entry point lookup or
+/// transform application fails. When diagnostic verification is enabled, the
+/// result of the verification is returned instead.
+static mlir::LogicalResult processPayloadBuffer(
+    raw_ostream &os, std::unique_ptr<MemoryBuffer> inputBuffer,
+    std::unique_ptr<llvm::MemoryBuffer> transformBuffer,
+    MutableArrayRef<std::unique_ptr<MemoryBuffer>> transformLibraries,
+    mlir::DialectRegistry &registry) {
+
+  // Initialize the MLIR context, and various configurations.
+  mlir::MLIRContext context(registry, mlir::MLIRContext::Threading::DISABLED);
+  context.allowUnregisteredDialects(clOptions->allowUnregisteredDialects);
+  mlir::ParserConfig config(&context);
+  TransformSourceMgr sourceMgr(
+      /*verifyDiagnostics=*/clOptions->verifyDiagnostics);
+
+  // Parse the input buffer that will be used as transform payload.
+  mlir::OwningOpRef<mlir::Operation *> payloadRoot =
+      sourceMgr.parseBuffer(std::move(inputBuffer), context, config);
+  if (!payloadRoot)
+    return sourceMgr.checkResult(mlir::failure());
+
+  // Identify the module containing the transform script entry point. This may
+  // be the same module as the input or a separate module. In the former case,
+  // make a copy of the module so it can be modified freely. Modification may
+  // happen in the script itself (at which point it could be rewriting itself
+  // during interpretation, leading to tricky memory errors) or by embedding
+  // library modules in the script.
+  mlir::OwningOpRef<mlir::ModuleOp> transformRoot;
+  if (transformBuffer) {
+    transformRoot = sourceMgr.parseBuffer<mlir::ModuleOp>(
+        std::move(transformBuffer), context, config);
+    if (!transformRoot)
+      return sourceMgr.checkResult(mlir::failure());
+  } else {
+    // The copy must be a module to serve as the transform root. Check this
+    // before cloning: an unchecked `cast` asserts in debug builds (and is
+    // undefined behavior otherwise) for any other top-level operation, and the
+    // clone would be leaked.
+    auto payloadModule = dyn_cast<mlir::ModuleOp>(payloadRoot.get());
+    if (!payloadModule) {
+      payloadRoot->emitError()
+          << "expected the top-level payload operation to be a "
+             "'builtin.module' when no separate transform file is provided";
+      return sourceMgr.checkResult(mlir::failure());
+    }
+    transformRoot = cast<mlir::ModuleOp>(payloadModule->clone());
+  }
+
+  // Parse and merge the libraries into the main transform module.
+  for (auto &&transformLibrary : transformLibraries) {
+    mlir::OwningOpRef<mlir::ModuleOp> libraryModule =
+        sourceMgr.parseBuffer<mlir::ModuleOp>(std::move(transformLibrary),
+                                              context, config);
+
+    if (!libraryModule ||
+        mlir::failed(mlir::transform::detail::mergeSymbolsInto(
+            *transformRoot, std::move(libraryModule))))
+      return sourceMgr.checkResult(mlir::failure());
+  }
+
+  // If requested, dump the combined transform module.
+  if (clOptions->dumpLibraryModule)
+    transformRoot->dump();
+
+  // Find the entry point symbol. Even if it had originally been in the payload
+  // module, it was cloned into the transform module so only look there.
+  mlir::transform::TransformOpInterface entryPoint =
+      mlir::transform::detail::findTransformEntryPoint(
+          *transformRoot, mlir::ModuleOp(), clOptions->transformEntryPoint);
+  if (!entryPoint)
+    return sourceMgr.checkResult(mlir::failure());
+
+  // Apply the requested transformations.
+  mlir::transform::TransformOptions transformOptions;
+  transformOptions.enableExpensiveChecks(!clOptions->disableExpensiveChecks);
+  if (mlir::failed(applyTransforms(*payloadRoot, entryPoint, transformOptions)))
+    return sourceMgr.checkResult(mlir::failure());
+
+  // Print the transformed result and check the captured diagnostics if
+  // requested.
+  payloadRoot->print(os);
+  return sourceMgr.checkResult(mlir::success());
+}
+
+/// Tool entry point: registers dialects and command-line options, opens the
+/// payload, output, transform and library files, and runs the interpreter.
+/// The output file is retained only if processing succeeds.
+static mlir::LogicalResult runMain(int argc, char **argv) {
+  // Register all upstream dialects and extensions. Specific uses are advised
+  // not to register all dialects indiscriminately but rather hand-pick what is
+  // necessary for their use case.
+  mlir::DialectRegistry registry;
+  mlir::registerAllDialects(registry);
+  mlir::registerAllExtensions(registry);
+  mlir::registerAllPasses();
+
+  // Explicitly register the transform dialect. This is not strictly necessary
+  // since it has been already registered as part of the upstream dialect list,
+  // but useful for example purposes for cases when dialects to register are
+  // hand-picked. The transform dialect must be registered.
+  registry.insert<mlir::transform::TransformDialect>();
+
+  // Register various command-line options. Note that the LLVM initializer
+  // object is a RAII that ensures correct destruction of command-line option
+  // objects inside ManagedStatic.
+  llvm::InitLLVM y(argc, argv);
+  mlir::registerAsmPrinterCLOptions();
+  mlir::registerMLIRContextCLOptions();
+  registerCLOptions();
+  llvm::cl::ParseCommandLineOptions(argc, argv,
+                                    "Minimal Transform dialect driver\n");
+
+  // Try opening the main input file.
+  std::string errorMessage;
+  std::unique_ptr<llvm::MemoryBuffer> payloadFile =
+      mlir::openInputFile(clOptions->payloadFilename, &errorMessage);
+  if (!payloadFile) {
+    llvm::errs() << errorMessage << "\n";
+    return mlir::failure();
+  }
+
+  // Try opening the output file.
+  std::unique_ptr<llvm::ToolOutputFile> outputFile =
+      mlir::openOutputFile(clOptions->outputFilename, &errorMessage);
+  if (!outputFile) {
+    llvm::errs() << errorMessage << "\n";
+    return mlir::failure();
+  }
+
+  // Try opening the main transform file if provided.
+  std::unique_ptr<llvm::MemoryBuffer> transformRootFile;
+  if (!clOptions->transformMainFilename.empty()) {
+    if (clOptions->transformMainFilename == clOptions->payloadFilename) {
+      llvm::errs() << "warning: " << clOptions->payloadFilename
+                   << " is provided as both payload and transform file\n";
+    } else {
+      transformRootFile =
+          mlir::openInputFile(clOptions->transformMainFilename, &errorMessage);
+      if (!transformRootFile) {
+        llvm::errs() << errorMessage << "\n";
+        return mlir::failure();
+      }
+    }
+  }
+
+  // Try opening transform library files if provided.
+  SmallVector<std::unique_ptr<llvm::MemoryBuffer>> transformLibraries;
+  transformLibraries.reserve(clOptions->transformLibraryFilenames.size());
+  for (llvm::StringRef filename : clOptions->transformLibraryFilenames) {
+    transformLibraries.emplace_back(
+        mlir::openInputFile(filename, &errorMessage));
+    if (!transformLibraries.back()) {
+      llvm::errs() << errorMessage << "\n";
+      return mlir::failure();
+    }
+  }
+
+  mlir::LogicalResult result = processPayloadBuffer(
+      outputFile->os(), std::move(payloadFile), std::move(transformRootFile),
+      transformLibraries, registry);
+
+  // `ToolOutputFile` removes the file it manages when destroyed unless `keep`
+  // has been called; without this, `-o <file>` would never leave a file
+  // behind. Only keep the output when processing succeeded.
+  if (mlir::succeeded(result))
+    outputFile->keep();
+  return result;
+}
+
+/// Process entry point: runs the tool and converts its result into a process
+/// exit code.
+int main(int argc, char **argv) {
+  const mlir::LogicalResult status = runMain(argc, argv);
+  return mlir::asMainReturnCode(status);
+}
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index 74921544c55578..baf07ea1f010ac 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -173,6 +173,7 @@ if(LLVM_BUILD_EXAMPLES)
     transform-opt-ch3
     transform-opt-ch4
     mlir-minimal-opt
+    mlir-transform-opt
     )
   if(MLIR_ENABLE_EXECUTION_ENGINE)
     list(APPEND MLIR_TEST_DEPENDS
diff --git a/mlir/test/Examples/transform-opt/empty.mlir b/mlir/test/Examples/transform-opt/empty.mlir
new file mode 100644
index 00000000000000..b525769db68822
--- /dev/null
+++ b/mlir/test/Examples/transform-opt/empty.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-transform-opt %s --transform=%p/self-contained.mlir | FileCheck %s
+// RUN: mlir-transform-opt %s --transform=%p/external-decl.mlir --verify-diagnostics
+// RUN: mlir-transform-opt %s --transform=%p/external-def.mlir --transform-entry-point=external_def | FileCheck %s --check-prefix=EXTERNAL
+// RUN: mlir-transform-opt %s --transform=%p/external-decl.mlir --transform-library=%p/external-def.mlir | FileCheck %s --check-prefix=EXTERNAL
+// RUN: mlir-transform-opt %s --transform=%p/syntax-error.mlir --verify-diagnostics
+// RUN: mlir-transform-opt %s --transform=%p/self-contained.mlir --transform-library=%p/syntax-error.mlir --verify-diagnostics
+// RUN: mlir-transform-opt %s --transform=%p/self-contained.mlir --transform-library=%p/external-def.mlir --transform-library=%p/syntax-error.mlir --verify-diagnostics
+
+// CHECK: IR printer: in self-contained
+// EXTERNAL: IR printer: external_def
+// CHECK-NOT: @__transform_main
+module {}
diff --git a/mlir/test/Examples/transform-opt/external-decl.mlir b/mlir/test/Examples/transform-opt/external-decl.mlir
new file mode 100644
index 00000000000000..5a73735892429b
--- /dev/null
+++ b/mlir/test/Examples/transform-opt/external-decl.mlir
@@ -0,0 +1,18 @@
+// This test just needs to parse. Note that the diagnostic message below will
+// be produced in *another* multi-file ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/82623


More information about the Mlir-commits mailing list