[Mlir-commits] [mlir] f05ff4f - [mlir][python] Apply py::module_local() to all classes.

Stella Laurenzo llvmlistbot at llvm.org
Mon Aug 30 22:19:48 PDT 2021


Author: Stella Laurenzo
Date: 2021-08-30T22:18:43-07:00
New Revision: f05ff4f7570ca009f22f5e2fe8c6361f28faaa8a

URL: https://github.com/llvm/llvm-project/commit/f05ff4f7570ca009f22f5e2fe8c6361f28faaa8a
DIFF: https://github.com/llvm/llvm-project/commit/f05ff4f7570ca009f22f5e2fe8c6361f28faaa8a.diff

LOG: [mlir][python] Apply py::module_local() to all classes.

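* Registering every class with py::module_local() keeps its pybind11 type registration local to the extension module that defines it, which allows multiple downstream projects that embed the MLIR Python API to co-exist in the same process.
* I believe this is the last change needed to enable isolated embedding.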

Differential Revision: https://reviews.llvm.org/D108605

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
    mlir/lib/Bindings/Python/IRAffine.cpp
    mlir/lib/Bindings/Python/IRAttributes.cpp
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/MainModule.cpp
    mlir/lib/Bindings/Python/Pass.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
index 510e3f8dd7929..765ff2826d5d2 100644
--- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
+++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
@@ -58,7 +58,7 @@ PYBIND11_MODULE(_mlirExecutionEngine, m) {
   //----------------------------------------------------------------------------
   // Mapping of the top-level PassManager
   //----------------------------------------------------------------------------
-  py::class_<PyExecutionEngine>(m, "ExecutionEngine")
+  py::class_<PyExecutionEngine>(m, "ExecutionEngine", py::module_local())
       .def(py::init<>([](MlirModule module, int optLevel,
                          const std::vector<std::string> &sharedLibPaths) {
              llvm::SmallVector<MlirStringRef, 4> libPaths;

diff  --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index 0a2a5666a9e47..5314badba64f0 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -97,7 +97,7 @@ class PyConcreteAffineExpr : public BaseTy {
   }
 
   static void bind(py::module &m) {
-    auto cls = ClassTy(m, DerivedTy::pyClassName);
+    auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
     cls.def(py::init<PyAffineExpr &>());
     DerivedTy::bindDerived(cls);
   }
@@ -367,7 +367,8 @@ class PyIntegerSetConstraint {
   bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); }
 
   static void bind(py::module &m) {
-    py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint")
+    py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint",
+                                       py::module_local())
         .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr)
         .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq);
   }
@@ -427,7 +428,7 @@ void mlir::python::populateIRAffine(py::module &m) {
   //----------------------------------------------------------------------------
   // Mapping of PyAffineExpr and derived classes.
   //----------------------------------------------------------------------------
-  py::class_<PyAffineExpr>(m, "AffineExpr")
+  py::class_<PyAffineExpr>(m, "AffineExpr", py::module_local())
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                              &PyAffineExpr::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule)
@@ -515,7 +516,7 @@ void mlir::python::populateIRAffine(py::module &m) {
   //----------------------------------------------------------------------------
   // Mapping of PyAffineMap.
   //----------------------------------------------------------------------------
-  py::class_<PyAffineMap>(m, "AffineMap")
+  py::class_<PyAffineMap>(m, "AffineMap", py::module_local())
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                              &PyAffineMap::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule)
@@ -686,7 +687,7 @@ void mlir::python::populateIRAffine(py::module &m) {
   //----------------------------------------------------------------------------
   // Mapping of PyIntegerSet.
   //----------------------------------------------------------------------------
-  py::class_<PyIntegerSet>(m, "IntegerSet")
+  py::class_<PyIntegerSet>(m, "IntegerSet", py::module_local())
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                              &PyIntegerSet::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)

diff  --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 0af762d93acb0..bb4b5f4f0462c 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -65,7 +65,8 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
     }
 
     static void bind(py::module &m) {
-      py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator")
+      py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator",
+                                           py::module_local())
           .def("__iter__", &PyArrayAttributeIterator::dunderIter)
           .def("__next__", &PyArrayAttributeIterator::dunderNext);
     }

diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 7add4eb7b0379..8672b772e54c9 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -163,7 +163,7 @@ struct PyGlobalDebugFlag {
 
   static void bind(py::module &m) {
     // Debug flags.
-    py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
+    py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local())
         .def_property_static("flag", &PyGlobalDebugFlag::get,
                              &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
   }
@@ -192,7 +192,7 @@ class PyRegionIterator {
   }
 
   static void bind(py::module &m) {
-    py::class_<PyRegionIterator>(m, "RegionIterator")
+    py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local())
         .def("__iter__", &PyRegionIterator::dunderIter)
         .def("__next__", &PyRegionIterator::dunderNext);
   }
@@ -224,7 +224,7 @@ class PyRegionList {
   }
 
   static void bind(py::module &m) {
-    py::class_<PyRegionList>(m, "RegionSequence")
+    py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
         .def("__len__", &PyRegionList::dunderLen)
         .def("__getitem__", &PyRegionList::dunderGetItem);
   }
@@ -252,7 +252,7 @@ class PyBlockIterator {
   }
 
   static void bind(py::module &m) {
-    py::class_<PyBlockIterator>(m, "BlockIterator")
+    py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local())
         .def("__iter__", &PyBlockIterator::dunderIter)
         .def("__next__", &PyBlockIterator::dunderNext);
   }
@@ -317,7 +317,7 @@ class PyBlockList {
   }
 
   static void bind(py::module &m) {
-    py::class_<PyBlockList>(m, "BlockList")
+    py::class_<PyBlockList>(m, "BlockList", py::module_local())
         .def("__getitem__", &PyBlockList::dunderGetItem)
         .def("__iter__", &PyBlockList::dunderIter)
         .def("__len__", &PyBlockList::dunderLen)
@@ -349,7 +349,7 @@ class PyOperationIterator {
   }
 
   static void bind(py::module &m) {
-    py::class_<PyOperationIterator>(m, "OperationIterator")
+    py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local())
         .def("__iter__", &PyOperationIterator::dunderIter)
         .def("__next__", &PyOperationIterator::dunderNext);
   }
@@ -405,7 +405,7 @@ class PyOperationList {
   }
 
   static void bind(py::module &m) {
-    py::class_<PyOperationList>(m, "OperationList")
+    py::class_<PyOperationList>(m, "OperationList", py::module_local())
         .def("__getitem__", &PyOperationList::dunderGetItem)
         .def("__iter__", &PyOperationList::dunderIter)
         .def("__len__", &PyOperationList::dunderLen);
@@ -1539,7 +1539,7 @@ class PyConcreteValue : public PyValue {
 
   /// Binds the Python module objects to functions of this class.
   static void bind(py::module &m) {
-    auto cls = ClassTy(m, DerivedTy::pyClassName);
+    auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>());
     DerivedTy::bindDerived(cls);
   }
@@ -1617,7 +1617,7 @@ class PyBlockArgumentList {
 
   /// Defines a Python class in the bindings.
   static void bind(py::module &m) {
-    py::class_<PyBlockArgumentList>(m, "BlockArgumentList")
+    py::class_<PyBlockArgumentList>(m, "BlockArgumentList", py::module_local())
         .def("__len__", &PyBlockArgumentList::dunderLen)
         .def("__getitem__", &PyBlockArgumentList::dunderGetItem);
   }
@@ -1764,7 +1764,7 @@ class PyOpAttributeMap {
   }
 
   static void bind(py::module &m) {
-    py::class_<PyOpAttributeMap>(m, "OpAttributeMap")
+    py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local())
         .def("__contains__", &PyOpAttributeMap::dunderContains)
         .def("__len__", &PyOpAttributeMap::dunderLen)
         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
@@ -1787,7 +1787,7 @@ void mlir::python::populateIRCore(py::module &m) {
   //----------------------------------------------------------------------------
   // Mapping of MlirContext.
   //----------------------------------------------------------------------------
-  py::class_<PyMlirContext>(m, "Context")
+  py::class_<PyMlirContext>(m, "Context", py::module_local())
       .def(py::init<>(&PyMlirContext::createNewContextForInit))
       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
       .def("_get_context_again",
@@ -1851,7 +1851,7 @@ void mlir::python::populateIRCore(py::module &m) {
   //----------------------------------------------------------------------------
   // Mapping of PyDialectDescriptor
   //----------------------------------------------------------------------------
-  py::class_<PyDialectDescriptor>(m, "DialectDescriptor")
+  py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local())
       .def_property_readonly("namespace",
                              [](PyDialectDescriptor &self) {
                                MlirStringRef ns =
@@ -1869,7 +1869,7 @@ void mlir::python::populateIRCore(py::module &m) {
   //----------------------------------------------------------------------------
   // Mapping of PyDialects
   //----------------------------------------------------------------------------
-  py::class_<PyDialects>(m, "Dialects")
+  py::class_<PyDialects>(m, "Dialects", py::module_local())
       .def("__getitem__",
            [=](PyDialects &self, std::string keyName) {
              MlirDialect dialect =
@@ -1889,7 +1889,7 @@ void mlir::python::populateIRCore(py::module &m) {
   //----------------------------------------------------------------------------
   // Mapping of PyDialect
   //----------------------------------------------------------------------------
-  py::class_<PyDialect>(m, "Dialect")
+  py::class_<PyDialect>(m, "Dialect", py::module_local())
       .def(py::init<py::object>(), "descriptor")
       .def_property_readonly(
           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
@@ -1904,7 +1904,7 @@ void mlir::python::populateIRCore(py::module &m) {
   //----------------------------------------------------------------------------
   // Mapping of Location
   //----------------------------------------------------------------------------
-  py::class_<PyLocation>(m, "Location")
+  py::class_<PyLocation>(m, "Location", py::module_local())
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
       .def("__enter__", &PyLocation::contextEnter)
@@ -1956,7 +1956,7 @@ void mlir::python::populateIRCore(py::module &m) {
   //----------------------------------------------------------------------------
   // Mapping of Module
   //----------------------------------------------------------------------------
-  py::class_<PyModule>(m, "Module")
+  py::class_<PyModule>(m, "Module", py::module_local())
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
       .def_static(
@@ -2025,7 +2025,7 @@ void mlir::python::populateIRCore(py::module &m) {
   //----------------------------------------------------------------------------
   // Mapping of Operation.
   //----------------------------------------------------------------------------
-  py::class_<PyOperationBase>(m, "_OperationBase")
+  py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
       .def("__eq__",
            [](PyOperationBase &self, PyOperationBase &other) {
              return &self.getOperation() == &other.getOperation();
@@ -2112,7 +2112,7 @@ void mlir::python::populateIRCore(py::module &m) {
           "Verify the operation and return true if it passes, false if it "
           "fails.");
 
-  py::class_<PyOperation, PyOperationBase>(m, "Operation")
+  py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
       .def_static("create", &PyOperation::create, py::arg("name"),
                   py::arg("results") = py::none(),
                   py::arg("operands") = py::none(),
@@ -2149,7 +2149,7 @@ void mlir::python::populateIRCore(py::module &m) {
       .def_property_readonly("opview", &PyOperation::createOpView);
 
   auto opViewClass =
-      py::class_<PyOpView, PyOperationBase>(m, "OpView")
+      py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
           .def(py::init<py::object>())
           .def_property_readonly("operation", &PyOpView::getOperationObject)
           .def_property_readonly(
@@ -2174,7 +2174,7 @@ void mlir::python::populateIRCore(py::module &m) {
   //----------------------------------------------------------------------------
   // Mapping of PyRegion.
   //----------------------------------------------------------------------------
-  py::class_<PyRegion>(m, "Region")
+  py::class_<PyRegion>(m, "Region", py::module_local())
       .def_property_readonly(
           "blocks",
           [](PyRegion &self) {
@@ -2198,7 +2198,7 @@ void mlir::python::populateIRCore(py::module &m) {
   //----------------------------------------------------------------------------
   // Mapping of PyBlock.
   //----------------------------------------------------------------------------
-  py::class_<PyBlock>(m, "Block")
+  py::class_<PyBlock>(m, "Block", py::module_local())
       .def_property_readonly(
           "owner",
           [](PyBlock &self) {
@@ -2288,7 +2288,7 @@ void mlir::python::populateIRCore(py::module &m) {
   // Mapping of PyInsertionPoint.
   //----------------------------------------------------------------------------
 
-  py::class_<PyInsertionPoint>(m, "InsertionPoint")
+  py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local())
       .def(py::init<PyBlock &>(), py::arg("block"),
            "Inserts after the last operation but still inside the block.")
       .def("__enter__", &PyInsertionPoint::contextEnter)
@@ -2318,7 +2318,7 @@ void mlir::python::populateIRCore(py::module &m) {
   //----------------------------------------------------------------------------
   // Mapping of PyAttribute.
   //----------------------------------------------------------------------------
-  py::class_<PyAttribute>(m, "Attribute")
+  py::class_<PyAttribute>(m, "Attribute", py::module_local())
       // Delegate to the PyAttribute copy constructor, which will also lifetime
       // extend the backing context which owns the MlirAttribute.
       .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
@@ -2389,7 +2389,7 @@ void mlir::python::populateIRCore(py::module &m) {
   //----------------------------------------------------------------------------
   // Mapping of PyNamedAttribute
   //----------------------------------------------------------------------------
-  py::class_<PyNamedAttribute>(m, "NamedAttribute")
+  py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local())
       .def("__repr__",
            [](PyNamedAttribute &self) {
              PyPrintAccumulator printAccum;
@@ -2425,7 +2425,7 @@ void mlir::python::populateIRCore(py::module &m) {
   //----------------------------------------------------------------------------
   // Mapping of PyType.
   //----------------------------------------------------------------------------
-  py::class_<PyType>(m, "Type")
+  py::class_<PyType>(m, "Type", py::module_local())
       // Delegate to the PyType copy constructor, which will also lifetime
       // extend the backing context which owns the MlirType.
       .def(py::init<PyType &>(), py::arg("cast_from_type"),
@@ -2479,7 +2479,7 @@ void mlir::python::populateIRCore(py::module &m) {
   //----------------------------------------------------------------------------
   // Mapping of Value.
   //----------------------------------------------------------------------------
-  py::class_<PyValue>(m, "Value")
+  py::class_<PyValue>(m, "Value", py::module_local())
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
       .def_property_readonly(

diff  --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 073ac90375a30..cbade532e8ed0 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -26,7 +26,7 @@ using namespace mlir::python;
 PYBIND11_MODULE(_mlir, m) {
   m.doc() = "MLIR Python Native Extension";
 
-  py::class_<PyGlobals>(m, "_Globals")
+  py::class_<PyGlobals>(m, "_Globals", py::module_local())
       .def_property("dialect_search_modules",
                     &PyGlobals::getDialectSearchPrefixes,
                     &PyGlobals::setDialectSearchPrefixes)

diff  --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index f2433573bd0c2..6aa1c651c9a06 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -55,7 +55,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
   //----------------------------------------------------------------------------
   // Mapping of the top-level PassManager
   //----------------------------------------------------------------------------
-  py::class_<PyPassManager>(m, "PassManager")
+  py::class_<PyPassManager>(m, "PassManager", py::module_local())
       .def(py::init<>([](DefaultingPyMlirContext context) {
              MlirPassManager passManager =
                  mlirPassManagerCreate(context->get());


        


More information about the Mlir-commits mailing list