[Mlir-commits] [mlir] [mlir] Add `MLIRContext::executeCriticalSection` and `Pass::getOpDependentDialects` methods. (PR #98953)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jul 15 13:53:08 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

Author: Fabian Mora (fabianmcg)

<details>
<summary>Changes</summary>

This patch adds the `MLIRContext::executeCriticalSection` and `Pass::getOpDependentDialects` methods. The `MLIRContext::executeCriticalSection` method runs a callback while holding a mutex owned by the `MLIRContext`, serializing it against other critical sections on the same context. It is used to load the dialects requested by `Pass::getOpDependentDialects`.

The `getOpDependentDialects` method allows loading dependent dialects, like the existing `Pass::getDependentDialects` method. Unlike that method, however, it also takes into consideration the operation being transformed.

Finally, `DialectRegistry::empty` was added to avoid unconditionally executing the critical section in `OpToOpPassAdaptor::run` that loads the dialects.

---
Full diff: https://github.com/llvm/llvm-project/pull/98953.diff


7 Files Affected:

- (modified) mlir/include/mlir/IR/DialectRegistry.h (+3) 
- (modified) mlir/include/mlir/IR/MLIRContext.h (+4) 
- (modified) mlir/include/mlir/Pass/Pass.h (+9) 
- (modified) mlir/lib/IR/MLIRContext.cpp (+21) 
- (modified) mlir/lib/Pass/Pass.cpp (+15) 
- (modified) mlir/unittests/Pass/CMakeLists.txt (+1) 
- (modified) mlir/unittests/Pass/PassManagerTest.cpp (+37) 


``````````diff
diff --git a/mlir/include/mlir/IR/DialectRegistry.h b/mlir/include/mlir/IR/DialectRegistry.h
index c13a1a1858eb1..901eebc2b18de 100644
--- a/mlir/include/mlir/IR/DialectRegistry.h
+++ b/mlir/include/mlir/IR/DialectRegistry.h
@@ -253,6 +253,9 @@ class DialectRegistry {
   /// contains all of the components of this registry.
   bool isSubsetOf(const DialectRegistry &rhs) const;
 
+  /// Returns true if neither dialects nor extensions have been registered.
+  bool empty() const { return extensions.empty() && registry.empty(); }
+
 private:
   MapTy registry;
   std::vector<std::unique_ptr<DialectExtensionBase>> extensions;
diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index 11e5329f43e68..558ab74512c74 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -240,6 +240,10 @@ class MLIRContext {
   /// (attributes, operations, types, etc.).
   llvm::hash_code getRegistryHash();
 
+  /// Run `function` while holding a context-owned mutex. This only serializes
+  /// it against other `executeCriticalSection` calls on the same context.
+  void executeCriticalSection(function_ref<void()> function);
+
   //===--------------------------------------------------------------------===//
   // Action API
   //===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
index 7725a3a2910bd..50cb8d2d69632 100644
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -70,6 +70,15 @@ class Pass {
   /// register the Affine dialect but does not need to register Linalg.
   virtual void getDependentDialects(DialectRegistry &registry) const {}
 
+  /// Register dependent dialects for the current pass and the operation being
+  /// transformed. This is similar to `getDependentDialects`, except that it
+  /// also receives `op`. Prefer `getDependentDialects` when possible, as this
+  /// method incurs extra synchronization overhead: it is invoked every time the
+  /// pass runs on an operation, possibly from several threads. Implementations
+  /// must not modify `op`.
+  virtual void getOpDependentDialects(Operation *op,
+                                      DialectRegistry &registry) const {}
+
   /// Return the command line argument used when registering this pass. Return
   /// an empty string if one does not exist.
   virtual StringRef getArgument() const { return ""; }
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 214b354c5347e..2967d54ae3144 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -199,6 +199,9 @@ class MLIRContextImpl {
   /// A mutex used when accessing operation information.
   llvm::sys::SmartRWMutex<true> operationInfoMutex;
 
+  /// Serializes `MLIRContext::executeCriticalSection` calls.
+  llvm::sys::SmartMutex<true> criticalSectionMutex;
+
   //===--------------------------------------------------------------------===//
   // Affine uniquing
   //===--------------------------------------------------------------------===//
@@ -703,6 +706,24 @@ bool MLIRContext::isOperationRegistered(StringRef name) {
   return RegisteredOperationName::lookup(name, this).has_value();
 }
 
+void MLIRContext::executeCriticalSection(function_ref<void()> function) {
+  if (!function)
+    return;
+  llvm::sys::SmartScopedLock<true> lock(impl->criticalSectionMutex);
+#ifndef NDEBUG
+  // Temporarily zero the multi-threading counter so registration asserts pass.
+  // Remove and re-add it as a delta instead of overwriting a saved value, so
+  // that concurrent enter/exit updates from other threads are not lost. Note
+  // that the lock only serializes critical sections with each other.
+  int multiThreadedExecutionContext = impl->multiThreadedExecutionContext;
+  impl->multiThreadedExecutionContext -= multiThreadedExecutionContext;
+#endif
+  function();
+#ifndef NDEBUG
+  impl->multiThreadedExecutionContext += multiThreadedExecutionContext;
+#endif
+}
+
 void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
   auto &impl = context->getImpl();
   assert(impl.multiThreadedExecutionContext == 0 &&
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 57a6c20141d2c..73345e31da7f5 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -513,6 +513,21 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
   };
   pass->passState.emplace(op, am, dynamicPipelineCallback);
 
+  // Get any op dependent dialects.
+  DialectRegistry dependentDialects;
+  pass->getOpDependentDialects(op, dependentDialects);
+  // Only lock if something was requested. Inspect the registry under the lock
+  // as other threads may append to it, and always load: registered != loaded.
+  if (!dependentDialects.empty()) {
+    MLIRContext &context = pass->getContext();
+    context.executeCriticalSection([&]() {
+      if (!dependentDialects.isSubsetOf(context.getDialectRegistry()))
+        context.appendDialectRegistry(dependentDialects);
+      for (StringRef name : dependentDialects.getDialectNames())
+        context.getOrLoadDialect(name);
+    });
+  }
+
   // Instrument before the pass has run.
   if (pi)
     pi->runBeforePass(pass, op);
diff --git a/mlir/unittests/Pass/CMakeLists.txt b/mlir/unittests/Pass/CMakeLists.txt
index 802b3bbc6c635..0bc3238c342b9 100644
--- a/mlir/unittests/Pass/CMakeLists.txt
+++ b/mlir/unittests/Pass/CMakeLists.txt
@@ -6,5 +6,6 @@ add_mlir_unittest(MLIRPassTests
 target_link_libraries(MLIRPassTests
   PRIVATE
+  MLIRArithDialect
   MLIRDebug
   MLIRFuncDialect
   MLIRPass)
diff --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp
index 7ceed3bb3bc3b..eed98a5fe61c9 100644
--- a/mlir/unittests/Pass/PassManagerTest.cpp
+++ b/mlir/unittests/Pass/PassManagerTest.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Debug/BreakpointManagers/TagBreakpointManager.h"
 #include "mlir/Debug/ExecutionContext.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -93,6 +94,12 @@ struct AddAttrFunctionPass
     : public PassWrapper<AddAttrFunctionPass, OperationPass<func::FuncOp>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddAttrFunctionPass)
 
+  void getOpDependentDialects(Operation *op,
+                              DialectRegistry &registry) const override {
+    // Request arith only for ops that already carry `didProcess`.
+    if (op->hasAttr("didProcess")) registry.insert<arith::ArithDialect>();
+  }
+
   void runOnOperation() override {
     func::FuncOp op = getOperation();
     Builder builder(op->getParentOfType<ModuleOp>());
@@ -281,4 +288,34 @@ TEST(PassManagerTest, PassInitialization) {
   EXPECT_TRUE(succeeded(pm.run(module.get())));
 }
 
+TEST(PassManagerTest, OpDependentDialects) {
+  MLIRContext context;
+  context.loadDialect<func::FuncDialect>();
+  Builder builder(&context);
+
+  // Build a module holding a single private function.
+  OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
+  auto funcOp =
+      func::FuncOp::create(builder.getUnknownLoc(), "function",
+                           builder.getFunctionType(std::nullopt, std::nullopt));
+  funcOp.setPrivate();
+  module->push_back(funcOp);
+
+  // Builds and runs a fresh pipeline nesting the pass on every function.
+  auto runPipeline = [&]() {
+    auto pm = PassManager::on<ModuleOp>(&context);
+    pm.addNestedPass<func::FuncOp>(std::make_unique<AddAttrFunctionPass>());
+    EXPECT_TRUE(succeeded(pm.run(module.get())));
+  };
+
+  // After the first run the arith dialect must still be unloaded.
+  runPipeline();
+  ASSERT_EQ(context.getLoadedDialect<arith::ArithDialect>(), nullptr);
+
+  // Run again: `didProcess` is now expected on the function, so the pass
+  // should request the arith dialect and the adaptor should load it.
+  runPipeline();
+  ASSERT_NE(context.getLoadedDialect<arith::ArithDialect>(), nullptr);
+}
+
 } // namespace

``````````

</details>


https://github.com/llvm/llvm-project/pull/98953


More information about the Mlir-commits mailing list