[Mlir-commits] [mlir] [mlir] add MlirOptMain config callback for context configuration (PR #68228)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Wed Oct 4 08:15:12 PDT 2023


https://github.com/ftynse created https://github.com/llvm/llvm-project/pull/68228

Add a MlirOptMainConfig option letting the caller provide a callback performing additional configuration of the MLIRContext object, similarly to how they can supply a callback for pass manager configuration.

Clients of MlirOptMain, especially downstream, may want to perform additional configuration of the context before running the pass pipeline and are currently unable to do so. This configuration may include preloading dialects (e.g., if a downstream wishes to make some of its dialects "built-in"), diagnostic engine configuration and resource management in context.

Exercise this functionality in the transform dialect from a unit test by letting it preload symbols.

From 9101578dcb4e970abce043c8340f6550eb86f360 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Wed, 4 Oct 2023 14:53:06 +0000
Subject: [PATCH] [mlir] add MlirOptMain config callback for context
 configuration

Add a MlirOptMainConfig option letting the caller provide a callback
performing additional configuration of the MLIRContext object, similarly
to how they can supply a callback for pass manager configuration.

Clients of MlirOptMain, especially downstream, may want to perform
additional configuration of the context before running the pass pipeline
and are currently unable to do so. This configuration may include
preloading dialects (e.g., if a downstream wishes to make some of its
dialects "built-in"), diagnostic engine configuration and resource
management in context.

Exercise this functionality in the transform dialect from a unit test by
letting it preload symbols.
---
 .../Dialect/Transform/IR/TransformDialect.td  |  23 ++-
 .../include/mlir/Tools/mlir-opt/MlirOptMain.h |  17 ++
 .../TransformInterpreterPassBase.cpp          |  21 +-
 mlir/lib/Pass/PassManagerOptions.cpp          |   6 +-
 mlir/lib/Tools/mlir-opt/MlirOptMain.cpp       |   3 +
 .../TestTransformDialectInterpreter.cpp       |   4 +
 mlir/unittests/CMakeLists.txt                 |   1 +
 mlir/unittests/Tools/CMakeLists.txt           |  11 +
 .../unittests/Tools/MlirOptContextPreload.cpp | 192 ++++++++++++++++++
 9 files changed, 272 insertions(+), 6 deletions(-)
 create mode 100644 mlir/unittests/Tools/CMakeLists.txt
 create mode 100644 mlir/unittests/Tools/MlirOptContextPreload.cpp

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
index 3448e27a41a6804..16739ac3e19336c 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
@@ -63,6 +63,22 @@ def Transform_Dialect : Dialect {
       using ExtensionTypePrintingHook =
           std::function<void (::mlir::Type, ::mlir::AsmPrinter &)>;
 
+      /// Takes ownership of the given module and records it as a transform
+      /// symbol library available to all users of the dialect.
+      void registerLibraryModule(
+          ::mlir::OwningOpRef<::mlir::ModuleOp> &&library) {
+        libraryModules.emplace_back(std::move(library));
+      }
+
+      /// Returns a range over the registered library modules, exposing each
+      /// as the module held by its owning reference.
+      auto getLibraryModules() const {
+        auto unwrap = [](const ::mlir::OwningOpRef<::mlir::ModuleOp> &owner) {
+          return owner.get();
+        };
+        return ::llvm::map_range(libraryModules, unwrap);
+      }
+
     private:
       /// Registers operations specified as template parameters with this
       /// dialect. Checks that they implement the required interfaces.
@@ -120,7 +136,12 @@ def Transform_Dialect : Dialect {
       /// A map from type TypeID to its printing function. No need to do string
       /// lookups when the type is fully constructed.
       ::llvm::DenseMap<::mlir::TypeID, ExtensionTypePrintingHook>
-      typePrintingHooks;
+          typePrintingHooks;
+
+      /// Modules containing symbols, e.g. named sequences, that will be
+      /// resolved by the interpreter when used.
+      ::llvm::SmallVector<::mlir::OwningOpRef<::mlir::ModuleOp>, 2>
+          libraryModules;
   }];
 }
 
diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
index 222a51e8db77eac..06ca84210698c11 100644
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -61,6 +61,19 @@ class MlirOptMainConfig {
     return allowUnregisteredDialectsFlag;
   }
 
+  /// Set the callback that configures the context once the regular dialect
+  /// registry has been processed, e.g., for preloading a subset of dialects.
+  template <typename FuncTy>
+  MlirOptMainConfig &setContextConfigurationFn(FuncTy &&fn) {
+    contextConfigurationCallback = std::forward<FuncTy>(fn);
+    return *this;
+  }
+  /// Runs the context configuration callback if set; otherwise succeeds.
+  LogicalResult setupContext(MLIRContext &context) const {
+    return contextConfigurationCallback ? contextConfigurationCallback(context)
+                                        : success();
+  }
+
   /// Set the debug configuration to use.
   MlirOptMainConfig &setDebugConfig(tracing::DebugConfig config) {
     debugConfig = std::move(config);
@@ -176,6 +189,10 @@ class MlirOptMainConfig {
   /// general.
   bool allowUnregisteredDialectsFlag = false;
 
+  /// The callback for additional configuration of the MLIR context after it has
+  /// been created and populated from the registry.
+  std::function<LogicalResult(MLIRContext &)> contextConfigurationCallback;
+
   /// Configuration for the debugging hooks.
   tracing::DebugConfig debugConfig;
 
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index 23640c92457a89d..4da0f22f734686c 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -477,6 +477,14 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
   if (parsedTransformModule && failed(mlir::verify(*parsedTransformModule)))
     return failure();
 
+  auto libraryRange = context->getOrLoadDialect<transform::TransformDialect>()
+                          ->getLibraryModules();
+  if (!transformLibraryFileName.empty() && !libraryRange.empty()) {
+    return emitError((*libraryRange.begin()).getLoc())
+           << "library already supplied through the dialect, cannot parse "
+              "another library as requested by pass flags";
+  }
+
   OwningOpRef<ModuleOp> parsedLibraryModule;
   if (failed(parseTransformModuleFromFile(context, transformLibraryFileName,
                                           parsedLibraryModule)))
@@ -502,16 +510,21 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
     }
   }
 
-  if (!parsedLibraryModule || !*parsedLibraryModule)
+  bool hasDialectLevelLibrary = !libraryRange.empty();
+  if (!hasDialectLevelLibrary && (!parsedLibraryModule || !*parsedLibraryModule))
     return success();
 
   if (sharedTransformModule && *sharedTransformModule) {
     if (failed(defineDeclaredSymbols(*sharedTransformModule->get().getBody(),
-                                     parsedLibraryModule.get())))
+                                     hasDialectLevelLibrary
+                                         ? *libraryRange.begin()
+                                         : parsedLibraryModule.get())))
       return failure();
   } else {
-    transformLibraryModule =
-        std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsedLibraryModule));
+    transformLibraryModule = std::make_shared<OwningOpRef<ModuleOp>>(
+        hasDialectLevelLibrary
+            ? OwningOpRef<ModuleOp>((*libraryRange.begin()).clone())
+            : std::move(parsedLibraryModule));
   }
   return success();
 }
diff --git a/mlir/lib/Pass/PassManagerOptions.cpp b/mlir/lib/Pass/PassManagerOptions.cpp
index ffc53b7e3ed0236..774c5e10c154cda 100644
--- a/mlir/lib/Pass/PassManagerOptions.cpp
+++ b/mlir/lib/Pass/PassManagerOptions.cpp
@@ -131,8 +131,12 @@ void mlir::registerPassManagerCLOptions() {
 }
 
 LogicalResult mlir::applyPassManagerCLOptions(PassManager &pm) {
-  if (!options.isConstructed())
+  if (!options.isConstructed()) {
+    emitError(UnknownLoc::get(pm.getContext()))
+        << "could not apply pass manager command line options.\n"
+           "Missing 'registerPassManagerCLOptions' call?\n";
     return failure();
+  }
 
   // Generate a reproducer on crash/failure.
   if (options->reproducerFile.getNumOccurrences())
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 644113058bdc1cc..87fcafe380302fb 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -326,6 +326,9 @@ static LogicalResult
 performActions(raw_ostream &os,
                const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
                MLIRContext *context, const MlirOptMainConfig &config) {
+  if (failed(config.setupContext(*context)))
+    return failure();
+
   DefaultTimingManager tm;
   applyDefaultTimingManagerCLOptions(tm);
   TimingScope timing = tm.getRootScope();
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
index f73deef9d5fd48c..c2da6b3f227a8ff 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
@@ -264,5 +264,9 @@ void registerTestTransformDialectEraseSchedulePass() {
 void registerTestTransformDialectInterpreterPass() {
   PassRegistration<TestTransformDialectInterpreterPass> reg;
 }
+/// Returns a newly allocated TestTransformDialectInterpreterPass as a Pass.
+std::unique_ptr<Pass> createTestTransformDialectInterpreterPass() {
+  return std::make_unique<TestTransformDialectInterpreterPass>();
+}
 } // namespace test
 } // namespace mlir
diff --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt
index d0e222091c9f896..5f6e54dccbd1800 100644
--- a/mlir/unittests/CMakeLists.txt
+++ b/mlir/unittests/CMakeLists.txt
@@ -17,6 +17,7 @@ add_subdirectory(Support)
 add_subdirectory(Rewrite)
 add_subdirectory(TableGen)
 add_subdirectory(Target)
+add_subdirectory(Tools)
 add_subdirectory(Transforms)
 
 if(MLIR_ENABLE_EXECUTION_ENGINE)
diff --git a/mlir/unittests/Tools/CMakeLists.txt b/mlir/unittests/Tools/CMakeLists.txt
new file mode 100644
index 000000000000000..7df0ab05caba16f
--- /dev/null
+++ b/mlir/unittests/Tools/CMakeLists.txt
@@ -0,0 +1,11 @@
+add_mlir_unittest(MLIRToolsMlirOptContextPreload
+  MlirOptContextPreload.cpp
+)
+
+target_link_libraries(MLIRToolsMlirOptContextPreload
+  PRIVATE
+  MLIROptLib
+  MLIRParser
+  MLIRTransformDialect
+  MLIRTestTransformDialect
+)
diff --git a/mlir/unittests/Tools/MlirOptContextPreload.cpp b/mlir/unittests/Tools/MlirOptContextPreload.cpp
new file mode 100644
index 000000000000000..ac7796a1e9ad62e
--- /dev/null
+++ b/mlir/unittests/Tools/MlirOptContextPreload.cpp
@@ -0,0 +1,192 @@
+//===- MlirOptContextPreload.cpp - Test MlirOptMain parameterization ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/TypeID.h"
+#include "mlir/Tools/mlir-opt/MlirOptMain.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/raw_ostream.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+namespace {
+/// A pass that signals failure unless the transform dialect is already loaded
+/// in the context. It also raises the flag referenced at construction whenever
+/// it runs, so tests can tell "succeeded" apart from "never ran at all".
+struct CheckIfTransformIsLoadedPass
+    : public PassWrapper<CheckIfTransformIsLoadedPass, OperationPass<>> {
+  explicit CheckIfTransformIsLoadedPass(bool &hasRun) : hasRun(hasRun) {}
+
+  void runOnOperation() override {
+    hasRun = true;
+    MLIRContext &ctx = getContext();
+    if (ctx.getLoadedDialect<transform::TransformDialect>() == nullptr)
+      signalPassFailure();
+  }
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CheckIfTransformIsLoadedPass)
+
+private:
+  bool &hasRun;
+};
+} // namespace
+
+TEST(MlirOptMain, ContextPreloadDialect) {
+  registerPassManagerCLOptions();
+
+  // Have the context hook load the transform dialect up front.
+  MlirOptMainConfig config;
+  config.setContextConfigurationFn([](MLIRContext &ctx) {
+    ctx.loadDialect<transform::TransformDialect>();
+    return success();
+  });
+
+  // The pipeline holds only the checking pass, which records that it ran.
+  bool hasRun = false;
+  config.setPassPipelineSetupFn([&hasRun](PassManager &manager) {
+    manager.addPass(std::make_unique<CheckIfTransformIsLoadedPass>(hasRun));
+    return success();
+  });
+
+  // Drive MlirOptMain over an empty buffer, discarding the printed output.
+  DialectRegistry registry;
+  std::unique_ptr<llvm::MemoryBuffer> emptyInput =
+      llvm::MemoryBuffer::getMemBuffer("", "<empty_buffer>");
+  LogicalResult mainResult =
+      MlirOptMain(llvm::nulls(), std::move(emptyInput), registry, config);
+
+  EXPECT_TRUE(succeeded(mainResult));
+  EXPECT_TRUE(hasRun) << "pass has not run";
+}
+
+TEST(MlirOptMain, ContextPreloadDialectNotLoaded) {
+  registerPassManagerCLOptions();
+
+  // No context hook is installed, so the transform dialect stays unloaded and
+  // the checking pass is expected to run and signal failure.
+  bool hasRun = false;
+  MlirOptMainConfig config;
+  config.setPassPipelineSetupFn([&hasRun](PassManager &manager) {
+    manager.addPass(std::make_unique<CheckIfTransformIsLoadedPass>(hasRun));
+    return success();
+  });
+
+  DialectRegistry registry;
+  std::unique_ptr<llvm::MemoryBuffer> emptyInput =
+      llvm::MemoryBuffer::getMemBuffer("", "<empty_buffer>");
+  LogicalResult mainResult =
+      MlirOptMain(llvm::nulls(), std::move(emptyInput), registry, config);
+
+  EXPECT_FALSE(succeeded(mainResult));
+  EXPECT_TRUE(hasRun) << "pass has not run";
+}
+
+TEST(MlirOptMain, ContextPreloadDialectFailure) {
+  registerPassManagerCLOptions();
+
+  // A failing context hook must abort the run before any pass is executed.
+  MlirOptMainConfig config;
+  config.setContextConfigurationFn(
+      [](MLIRContext &) -> LogicalResult { return failure(); });
+  bool hasRun = false;
+  config.setPassPipelineSetupFn([&hasRun](PassManager &manager) {
+    manager.addPass(std::make_unique<CheckIfTransformIsLoadedPass>(hasRun));
+    return success();
+  });
+
+  DialectRegistry registry;
+  std::unique_ptr<llvm::MemoryBuffer> emptyInput =
+      llvm::MemoryBuffer::getMemBuffer("", "<empty_buffer>");
+  LogicalResult mainResult =
+      MlirOptMain(llvm::nulls(), std::move(emptyInput), registry, config);
+
+  EXPECT_FALSE(succeeded(mainResult));
+  EXPECT_FALSE(hasRun) << "pass was not expected to run";
+}
+
+namespace mlir {
+namespace test {
+std::unique_ptr<Pass> createTestTransformDialectInterpreterPass();
+} // namespace test
+} // namespace mlir
+namespace test {
+void registerTestTransformDialectExtension(DialectRegistry &registry);
+} // namespace test
+
+const static llvm::StringLiteral library = R"MLIR(
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence public @print_remark(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "from external symbol" : !transform.any_op
+    transform.yield
+  }
+})MLIR";
+
+const static llvm::StringLiteral input = R"MLIR(
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence private @print_remark(%arg0: !transform.any_op {transform.readonly})
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    include @print_remark failures(propagate) (%arg0) : (!transform.any_op) -> ()
+  }
+})MLIR";
+
+TEST(MlirOptMain, ContextPreloadConstructedLibrary) {
+  registerPassManagerCLOptions();
+
+  // Make sure the transform dialect is always loaded and make it own a library
+  // module that will be used by the pass.
+  bool emittedDiagnostic = false;
+  MlirOptMainConfig config;
+  config.setContextConfigurationFn([&](MLIRContext &context) {
+    auto *dialect = context.getOrLoadDialect<transform::TransformDialect>();
+
+    ParserConfig parserConfig(&context);
+    OwningOpRef<ModuleOp> transformLibrary = parseSourceString<ModuleOp>(
+        library, parserConfig, "<transform-library>");
+    if (!transformLibrary)
+      return failure();
+
+    dialect->registerLibraryModule(std::move(transformLibrary));
+
+    // Flag the remark that only the library definition of @print_remark emits.
+    context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
+      if (diag.getSeverity() == DiagnosticSeverity::Remark &&
+          diag.str() == "from external symbol") {
+        emittedDiagnostic = true;
+      }
+    });
+
+    return success();
+  });
+
+  // Pass pipeline configuration.
+  config.setPassPipelineSetupFn([](PassManager &pm) {
+    pm.addPass(mlir::test::createTestTransformDialectInterpreterPass());
+    return success();
+  });
+
+  // We need to register the test extension since the input contains its ops.
+  DialectRegistry registry;
+  ::test::registerTestTransformDialectExtension(registry);
+  // The printed IR is captured in a string but not inspected by this test.
+  std::string output;
+  llvm::raw_string_ostream os(output);
+  LogicalResult mainResult = MlirOptMain(
+      os, llvm::MemoryBuffer::getMemBuffer(input, "<input>"), registry, config);
+
+  EXPECT_TRUE(succeeded(mainResult));
+  EXPECT_TRUE(emittedDiagnostic)
+      << "did not produce the expected diagnostic from external symbol";
+}



More information about the Mlir-commits mailing list