[Mlir-commits] [mlir] [mlir] Add `MLIRContext::executeCriticalSection` and `Pass::getOpDependentDialects` methods. (PR #98953)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jul 15 13:53:08 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
Author: Fabian Mora (fabianmcg)
<details>
<summary>Changes</summary>
This patch adds the `MLIRContext::executeCriticalSection` and `Pass::getOpDependentDialects` methods. The `MLIRContext::executeCriticalSection` method allows executing critical sections with respect to the `MLIRContext`. This method is required by `Pass::getOpDependentDialects`.
The `getOpDependentDialects` method loads dependent dialects, like the existing `Pass::getDependentDialects` method. However, the new method can also take into consideration the operation being transformed.
Finally, `DialectRegistry::empty` was added so that `OpToOpPassAdaptor::run` does not always execute the critical section that loads the dialects.
---
Full diff: https://github.com/llvm/llvm-project/pull/98953.diff
7 Files Affected:
- (modified) mlir/include/mlir/IR/DialectRegistry.h (+3)
- (modified) mlir/include/mlir/IR/MLIRContext.h (+4)
- (modified) mlir/include/mlir/Pass/Pass.h (+9)
- (modified) mlir/lib/IR/MLIRContext.cpp (+21)
- (modified) mlir/lib/Pass/Pass.cpp (+15)
- (modified) mlir/unittests/Pass/CMakeLists.txt (+1)
- (modified) mlir/unittests/Pass/PassManagerTest.cpp (+37)
``````````diff
diff --git a/mlir/include/mlir/IR/DialectRegistry.h b/mlir/include/mlir/IR/DialectRegistry.h
index c13a1a1858eb1..901eebc2b18de 100644
--- a/mlir/include/mlir/IR/DialectRegistry.h
+++ b/mlir/include/mlir/IR/DialectRegistry.h
@@ -253,6 +253,9 @@ class DialectRegistry {
/// contains all of the components of this registry.
bool isSubsetOf(const DialectRegistry &rhs) const;
+ /// Returns true if the registry is empty.
+ bool empty() const { return registry.empty() && extensions.empty(); }
+
private:
MapTy registry;
std::vector<std::unique_ptr<DialectExtensionBase>> extensions;
diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index 11e5329f43e68..558ab74512c74 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -240,6 +240,10 @@ class MLIRContext {
/// (attributes, operations, types, etc.).
llvm::hash_code getRegistryHash();
+ /// Execute a critical section guarded by the context. This method guarantees
+ /// that calling `function` is thread-safe with respect to the context.
+ void executeCriticalSection(function_ref<void()> function);
+
//===--------------------------------------------------------------------===//
// Action API
//===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
index 7725a3a2910bd..50cb8d2d69632 100644
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -70,6 +70,15 @@ class Pass {
/// register the Affine dialect but does not need to register Linalg.
virtual void getDependentDialects(DialectRegistry &registry) const {}
+ /// Register dependent dialects for the current pass and operation being
+ /// transformed. This function is similar to `getDependentDialects` except
+ /// that it also receives the operation being transformed. When possible, use
+ /// `getDependentDialects` as this method incurs extra synchronization
+ /// overhead. No transformations to `op` should be performed during this
+ /// method.
+ virtual void getOpDependentDialects(Operation *op,
+ DialectRegistry &registry) const {}
+
/// Return the command line argument used when registering this pass. Return
/// an empty string if one does not exist.
virtual StringRef getArgument() const { return ""; }
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 214b354c5347e..2967d54ae3144 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -199,6 +199,9 @@ class MLIRContextImpl {
/// A mutex used when accessing operation information.
llvm::sys::SmartRWMutex<true> operationInfoMutex;
+ /// A mutex used when running critical sections.
+ llvm::sys::SmartMutex<true> criticalSectionMutex;
+
//===--------------------------------------------------------------------===//
// Affine uniquing
//===--------------------------------------------------------------------===//
@@ -703,6 +706,24 @@ bool MLIRContext::isOperationRegistered(StringRef name) {
return RegisteredOperationName::lookup(name, this).has_value();
}
+void MLIRContext::executeCriticalSection(function_ref<void()> function) {
+ if (!function)
+ return;
+ llvm::sys::SmartScopedLock<true> lock(impl->criticalSectionMutex);
+#ifndef NDEBUG
+ // Temporarily disable `multiThreadedExecutionContext` so the context is aware
+ // only a single thread is running.
+ int multiThreadedExecutionContext = impl->multiThreadedExecutionContext;
+ impl->multiThreadedExecutionContext = 0;
+#endif
+ // Execute the critical section.
+ function();
+#ifndef NDEBUG
+ // Reset `multiThreadedExecutionContext` to its original state.
+ impl->multiThreadedExecutionContext = multiThreadedExecutionContext;
+#endif
+}
+
void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
auto &impl = context->getImpl();
assert(impl.multiThreadedExecutionContext == 0 &&
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 57a6c20141d2c..73345e31da7f5 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -513,6 +513,21 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
};
pass->passState.emplace(op, am, dynamicPipelineCallback);
+ // Get any op dependent dialects.
+ DialectRegistry dependentDialects;
+ pass->getOpDependentDialects(op, dependentDialects);
+ MLIRContext &context = pass->getContext();
+ // Only append a non-empty non-trivial registry.
+ if (!dependentDialects.empty() &&
+ !dependentDialects.isSubsetOf(context.getDialectRegistry())) {
+ context.executeCriticalSection([&]() {
+ // Load the dialects.
+ context.appendDialectRegistry(dependentDialects);
+ for (StringRef name : dependentDialects.getDialectNames())
+ context.getOrLoadDialect(name);
+ });
+ }
+
// Instrument before the pass has run.
if (pi)
pi->runBeforePass(pass, op);
diff --git a/mlir/unittests/Pass/CMakeLists.txt b/mlir/unittests/Pass/CMakeLists.txt
index 802b3bbc6c635..0bc3238c342b9 100644
--- a/mlir/unittests/Pass/CMakeLists.txt
+++ b/mlir/unittests/Pass/CMakeLists.txt
@@ -6,5 +6,6 @@ add_mlir_unittest(MLIRPassTests
target_link_libraries(MLIRPassTests
PRIVATE
MLIRDebug
+ MLIRArithDialect
MLIRFuncDialect
MLIRPass)
diff --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp
index 7ceed3bb3bc3b..eed98a5fe61c9 100644
--- a/mlir/unittests/Pass/PassManagerTest.cpp
+++ b/mlir/unittests/Pass/PassManagerTest.cpp
@@ -9,6 +9,7 @@
#include "mlir/Pass/PassManager.h"
#include "mlir/Debug/BreakpointManagers/TagBreakpointManager.h"
#include "mlir/Debug/ExecutionContext.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
@@ -93,6 +94,12 @@ struct AddAttrFunctionPass
: public PassWrapper<AddAttrFunctionPass, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddAttrFunctionPass)
+ void getOpDependentDialects(Operation *op,
+ DialectRegistry &registry) const override {
+ if (op->hasAttr("didProcess"))
+ registry.insert<arith::ArithDialect>();
+ }
+
void runOnOperation() override {
func::FuncOp op = getOperation();
Builder builder(op->getParentOfType<ModuleOp>());
@@ -281,4 +288,34 @@ TEST(PassManagerTest, PassInitialization) {
EXPECT_TRUE(succeeded(pm.run(module.get())));
}
+TEST(PassManagerTest, OpDependentDialects) {
+ MLIRContext context;
+ context.loadDialect<func::FuncDialect>();
+ Builder builder(&context);
+
+ // Create a module with 1 function.
+ OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
+ auto f =
+ func::FuncOp::create(builder.getUnknownLoc(), "function",
+ builder.getFunctionType(std::nullopt, std::nullopt));
+ f.setPrivate();
+ module->push_back(f);
+
+ // Instantiate and run our pass the first time.
+ {
+ auto pm = PassManager::on<ModuleOp>(&context);
+ pm.addNestedPass<func::FuncOp>(std::make_unique<AddAttrFunctionPass>());
+ EXPECT_TRUE(succeeded(pm.run(module.get())));
+ }
+ ASSERT_EQ(context.getLoadedDialect<arith::ArithDialect>(), nullptr);
+
+ // Run the pass a second time, this time it should load the arith dialect.
+ {
+ auto pm = PassManager::on<ModuleOp>(&context);
+ pm.addNestedPass<func::FuncOp>(std::make_unique<AddAttrFunctionPass>());
+ EXPECT_TRUE(succeeded(pm.run(module.get())));
+ }
+ ASSERT_NE(context.getLoadedDialect<arith::ArithDialect>(), nullptr);
+}
+
} // namespace
``````````
</details>
https://github.com/llvm/llvm-project/pull/98953
More information about the Mlir-commits
mailing list