[Mlir-commits] [mlir] 8260db7 - [mlir][Python] Return and accept OpView for all functions.

Stella Laurenzo llvmlistbot at llvm.org
Tue Nov 3 22:49:06 PST 2020


Author: Stella Laurenzo
Date: 2020-11-03T22:48:34-08:00
New Revision: 8260db752c91e0c368b88607132be0a9cd9362ba

URL: https://github.com/llvm/llvm-project/commit/8260db752c91e0c368b88607132be0a9cd9362ba
DIFF: https://github.com/llvm/llvm-project/commit/8260db752c91e0c368b88607132be0a9cd9362ba.diff

LOG: [mlir][Python] Return and accept OpView for all functions.

* All functions that return an Operation now return an OpView.
* All functions that accept an Operation now accept an _OperationBase, which both Operation and OpView extend and which can resolve to the backing Operation.
* Moves user-facing instance methods from Operation -> _OperationBase so that both can have the same API.
* Concretely, this means that if custom op classes are defined (e.g. in Python), any iteration or creation will return the appropriate instance (e.g. if you get/create an std.addf, you will get an instance of the mlir.dialects.std.AddFOp class, with full access to any custom API it exposes).
* Refactors all __eq__ methods to use a typed overload plus a generic py::object overload that returns False, instead of try/cast/catch; this is the approach that works for _OperationBase.

Differential Revision: https://reviews.llvm.org/D90584

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/Globals.h
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/lib/Bindings/Python/IRModules.h
    mlir/lib/Bindings/Python/MainModule.cpp
    mlir/test/Bindings/Python/ir_operation.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index 33ab4cd6722d..6613d2b6963c 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -42,13 +42,17 @@ class PyGlobals {
     dialectSearchPrefixes.swap(newValues);
   }
 
+  /// Clears positive and negative caches regarding what implementations are
+  /// available. Future lookups will do more expensive existence checks.
+  void clearImportCache();
+
   /// Loads a python module corresponding to the given dialect namespace.
   /// No-ops if the module has already been loaded or is not found. Raises
   /// an error on any evaluation issues.
   /// Note that this returns void because it is expected that the module
   /// contains calls to decorators and helpers that register the salient
   /// entities.
-  void loadDialectModule(const std::string &dialectNamespace);
+  void loadDialectModule(llvm::StringRef dialectNamespace);
 
   /// Decorator for registering a custom Dialect class. The class object must
   /// have a DIALECT_NAMESPACE attribute.
@@ -65,27 +69,39 @@ class PyGlobals {
   /// This is intended to be called by implementation code.
   void registerOperationImpl(const std::string &operationName,
                              pybind11::object pyClass,
-                             pybind11::object rawClass);
+                             pybind11::object rawOpViewClass);
 
   /// Looks up a registered dialect class by namespace. Note that this may
   /// trigger loading of the defining module and can arbitrarily re-enter.
   llvm::Optional<pybind11::object>
   lookupDialectClass(const std::string &dialectNamespace);
 
+  /// Looks up a registered raw OpView class by operation name. Note that this
+  /// may trigger a load of the dialect, which can arbitrarily re-enter.
+  llvm::Optional<pybind11::object>
+  lookupRawOpViewClass(llvm::StringRef operationName);
+
 private:
   static PyGlobals *instance;
   /// Module name prefixes to search under for dialect implementation modules.
   std::vector<std::string> dialectSearchPrefixes;
-  /// Map of dialect namespace to bool flag indicating whether the module has
-  /// been successfully loaded or resolved to not found.
-  llvm::StringSet<> loadedDialectModules;
   /// Map of dialect namespace to external dialect class object.
   llvm::StringMap<pybind11::object> dialectClassMap;
   /// Map of full operation name to external operation class object.
   llvm::StringMap<pybind11::object> operationClassMap;
   /// Map of operation name to custom subclass that directly initializes
   /// the OpView base class (bypassing the user class constructor).
-  llvm::StringMap<pybind11::object> rawOperationClassMap;
+  llvm::StringMap<pybind11::object> rawOpViewClassMap;
+
+  /// Set of dialect namespaces that we have attempted to import implementation
+  /// modules for.
+  llvm::StringSet<> loadedDialectModulesCache;
+  /// Cache of operation name to custom OpView subclass that directly
+  /// initializes the OpView base class (or None for a negative lookup).
+  /// This is maintained on lookup as a shadow of rawOpViewClassMap in
+  /// order for repeat lookups of the OpView classes to only incur the cost
+  /// of one hashtable lookup.
+  llvm::StringMap<pybind11::object> rawOpViewClassMapCache;
 };
 
 } // namespace python

diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 8c17e8e6d933..a4862ee59b89 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -407,7 +407,7 @@ class PyOperationIterator {
     PyOperationRef returnOperation =
         PyOperation::forOperation(parentOperation->getContext(), next);
     next = mlirOperationGetNextInBlock(next);
-    return returnOperation.releaseObject();
+    return returnOperation->createOpView();
   }
 
   static void bind(py::module &m) {
@@ -457,7 +457,7 @@ class PyOperationList {
     while (!mlirOperationIsNull(childOp)) {
       if (index == 0) {
         return PyOperation::forOperation(parentOperation->getContext(), childOp)
-            .releaseObject();
+            ->createOpView();
       }
       childOp = mlirOperationGetNextInBlock(childOp);
       index -= 1;
@@ -868,11 +868,12 @@ void PyOperation::checkValid() {
   }
 }
 
-void PyOperation::print(py::object fileObject, bool binary,
-                        llvm::Optional<int64_t> largeElementsLimit,
-                        bool enableDebugInfo, bool prettyDebugInfo,
-                        bool printGenericOpForm, bool useLocalScope) {
-  checkValid();
+void PyOperationBase::print(py::object fileObject, bool binary,
+                            llvm::Optional<int64_t> largeElementsLimit,
+                            bool enableDebugInfo, bool prettyDebugInfo,
+                            bool printGenericOpForm, bool useLocalScope) {
+  PyOperation &operation = getOperation();
+  operation.checkValid();
   if (fileObject.is_none())
     fileObject = py::module::import("sys").attr("stdout");
   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
@@ -885,15 +886,16 @@ void PyOperation::print(py::object fileObject, bool binary,
 
   PyFileAccumulator accum(fileObject, binary);
   py::gil_scoped_release();
-  mlirOperationPrintWithFlags(get(), flags, accum.getCallback(),
+  mlirOperationPrintWithFlags(operation.get(), flags, accum.getCallback(),
                               accum.getUserData());
   mlirOpPrintingFlagsDestroy(flags);
 }
 
-py::object PyOperation::getAsm(bool binary,
-                               llvm::Optional<int64_t> largeElementsLimit,
-                               bool enableDebugInfo, bool prettyDebugInfo,
-                               bool printGenericOpForm, bool useLocalScope) {
+py::object PyOperationBase::getAsm(bool binary,
+                                   llvm::Optional<int64_t> largeElementsLimit,
+                                   bool enableDebugInfo, bool prettyDebugInfo,
+                                   bool printGenericOpForm,
+                                   bool useLocalScope) {
   py::object fileObject;
   if (binary) {
     fileObject = py::module::import("io").attr("BytesIO")();
@@ -1034,12 +1036,24 @@ py::object PyOperation::create(
       ip->insert(*created.get());
   }
 
-  return created.releaseObject();
+  return created->createOpView();
 }
 
-PyOpView::PyOpView(py::object operation)
-    : operationObject(std::move(operation)),
-      operation(py::cast<PyOperation *>(this->operationObject)) {}
+py::object PyOperation::createOpView() {
+  MlirIdentifier ident = mlirOperationGetName(get());
+  MlirStringRef identStr = mlirIdentifierStr(ident);
+  auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
+      llvm::StringRef(identStr.data, identStr.length));
+  if (opViewClass) // A raw OpView subclass is registered for this op name.
+    return (*opViewClass)(getRef().getObject());
+  return py::cast(PyOpView(getRef().getObject())); // Generic fallback.
+}
+
+PyOpView::PyOpView(py::object operationObject)
+    // Cast through PyOperationBase so any subclass (Operation or OpView) is
+    // accepted. Requires 'operation' to be declared before 'operationObject'.
+    : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
+      operationObject(operation.getRef().getObject()) {}
 
 py::object PyOpView::createRawSubclass(py::object userClass) {
   // This is... a little gross. The typical pattern is to have a pure python
@@ -1082,11 +1096,12 @@ py::object PyOpView::createRawSubclass(py::object userClass) {
 
 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
 
-PyInsertionPoint::PyInsertionPoint(PyOperation &beforeOperation)
-    : block(beforeOperation.getBlock()),
-      refOperation(beforeOperation.getRef()) {}
+PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
+    : refOperation(beforeOperationBase.getOperation().getRef()),
+      block((*refOperation)->getBlock()) {}
 
-void PyInsertionPoint::insert(PyOperation &operation) {
+void PyInsertionPoint::insert(PyOperationBase &operationBase) {
+  PyOperation &operation = operationBase.getOperation();
   if (operation.isAttached())
     throw SetPyError(PyExc_ValueError,
                      "Attempt to insert operation that is already attached");
@@ -2501,33 +2516,36 @@ void mlir::python::populateIRSubmodule(py::module &m) {
   //----------------------------------------------------------------------------
   // Mapping of Operation.
   //----------------------------------------------------------------------------
-  py::class_<PyOperation>(m, "Operation")
-      .def_static("create", &PyOperation::create, py::arg("name"),
-                  py::arg("operands") = py::none(),
-                  py::arg("results") = py::none(),
-                  py::arg("attributes") = py::none(),
-                  py::arg("successors") = py::none(), py::arg("regions") = 0,
-                  py::arg("loc") = py::none(), py::arg("ip") = py::none(),
-                  kOperationCreateDocstring)
-      .def_property_readonly(
-          "context",
-          [](PyOperation &self) { return self.getContext().getObject(); },
-          "Context that owns the Operation")
-      .def_property_readonly(
-          "operands",
-          [](PyOperation &self) { return PyOpOperandList(self.getRef()); })
-      .def_property_readonly(
-          "regions",
-          [](PyOperation &self) { return PyRegionList(self.getRef()); })
+  py::class_<PyOperationBase>(m, "_OperationBase") // Operation/OpView base.
+      .def("__eq__",
+           [](PyOperationBase &self, PyOperationBase &other) {
+             return &self.getOperation() == &other.getOperation();
+           })
+      .def("__eq__", // Fallback: anything else compares unequal.
+           [](PyOperationBase &self, py::object other) { return false; })
+      .def_property_readonly("operands",
+                             [](PyOperationBase &self) {
+                               return PyOpOperandList(
+                                   self.getOperation().getRef());
+                             })
+      .def_property_readonly("regions",
+                             [](PyOperationBase &self) {
+                               return PyRegionList(
+                                   self.getOperation().getRef());
+                             })
       .def_property_readonly(
           "results",
-          [](PyOperation &self) { return PyOpResultList(self.getRef()); },
+          [](PyOperationBase &self) {
+            return PyOpResultList(self.getOperation().getRef());
+          },
           "Returns the list of Operation results.")
       .def("__iter__",
-           [](PyOperation &self) { return PyRegionIterator(self.getRef()); })
+           [](PyOperationBase &self) {
+             return PyRegionIterator(self.getOperation().getRef());
+           })
       .def(
           "__str__",
-          [](PyOperation &self) {
+          [](PyOperationBase &self) {
             return self.getAsm(/*binary=*/false,
                                /*largeElementsLimit=*/llvm::None,
                                /*enableDebugInfo=*/false,
@@ -2536,7 +2554,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
                                /*useLocalScope=*/false);
           },
           "Returns the assembly form of the operation.")
-      .def("print", &PyOperation::print,
+      .def("print", &PyOperationBase::print,
            // Careful: Lots of arguments must match up with print method.
            py::arg("file") = py::none(), py::arg("binary") = false,
            py::arg("large_elements_limit") = py::none(),
@@ -2544,7 +2562,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
            py::arg("pretty_debug_info") = false,
            py::arg("print_generic_op_form") = false,
            py::arg("use_local_scope") = false, kOperationPrintDocstring)
-      .def("get_asm", &PyOperation::getAsm,
+      .def("get_asm", &PyOperationBase::getAsm,
            // Careful: Lots of arguments must match up with get_asm method.
            py::arg("binary") = false,
            py::arg("large_elements_limit") = py::none(),
@@ -2553,9 +2571,29 @@ void mlir::python::populateIRSubmodule(py::module &m) {
            py::arg("print_generic_op_form") = false,
            py::arg("use_local_scope") = false, kOperationGetAsmDocstring);
 
-  py::class_<PyOpView>(m, "OpView")
+  py::class_<PyOperation, PyOperationBase>(m, "Operation")
+      .def_static("create", &PyOperation::create, py::arg("name"),
+                  py::arg("operands") = py::none(),
+                  py::arg("results") = py::none(),
+                  py::arg("attributes") = py::none(),
+                  py::arg("successors") = py::none(), py::arg("regions") = 0,
+                  py::arg("loc") = py::none(), py::arg("ip") = py::none(),
+                  kOperationCreateDocstring)
+      .def_property_readonly(
+          "context",
+          [](PyOperation &self) { return self.getContext().getObject(); },
+          "Context that owns the Operation")
+      .def_property_readonly("opview", &PyOperation::createOpView);
+
+  py::class_<PyOpView, PyOperationBase>(m, "OpView")
       .def(py::init<py::object>())
       .def_property_readonly("operation", &PyOpView::getOperationObject)
+      .def_property_readonly(
+          "context",
+          [](PyOpView &self) {
+            return self.getOperation().getContext().getObject();
+          },
+          "Context that owns the Operation")
       .def("__str__",
            [](PyOpView &self) { return py::str(self.getOperationObject()); });
 
@@ -2577,14 +2615,11 @@ void mlir::python::populateIRSubmodule(py::module &m) {
             return PyBlockIterator(self.getParentOperation(), firstBlock);
           },
           "Iterates over blocks in the region.")
-      .def("__eq__", [](PyRegion &self, py::object &other) {
-        try {
-          PyRegion *otherRegion = other.cast<PyRegion *>();
-          return self.get().ptr == otherRegion->get().ptr;
-        } catch (std::exception &e) {
-          return false;
-        }
-      });
+      .def("__eq__",
+           [](PyRegion &self, PyRegion &other) {
+             return self.get().ptr == other.get().ptr;
+           })
+      .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
 
   //----------------------------------------------------------------------------
   // Mapping of PyBlock.
@@ -2613,14 +2648,10 @@ void mlir::python::populateIRSubmodule(py::module &m) {
           },
           "Iterates over operations in the block.")
       .def("__eq__",
-           [](PyBlock &self, py::object &other) {
-             try {
-               PyBlock *otherBlock = other.cast<PyBlock *>();
-               return self.get().ptr == otherBlock->get().ptr;
-             } catch (std::exception &e) {
-               return false;
-             }
+           [](PyBlock &self, PyBlock &other) {
+             return self.get().ptr == other.get().ptr;
            })
+      .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
       .def(
           "__str__",
           [](PyBlock &self) {
@@ -2651,7 +2682,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
           },
           "Gets the InsertionPoint bound to the current thread or raises "
           "ValueError if none has been set")
-      .def(py::init<PyOperation &>(), py::arg("beforeOperation"),
+      .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
            "Inserts before a referenced operation.")
       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
                   py::arg("block"), "Inserts at the beginning of the block.")
@@ -2696,14 +2727,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
           },
           py::keep_alive<0, 1>(), "Binds a name to the attribute")
       .def("__eq__",
-           [](PyAttribute &self, py::object &other) {
-             try {
-               PyAttribute otherAttribute = other.cast<PyAttribute>();
-               return self == otherAttribute;
-             } catch (std::exception &e) {
-               return false;
-             }
-           })
+           [](PyAttribute &self, PyAttribute &other) { return self == other; })
+      .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
       .def(
           "dump", [](PyAttribute &self) { mlirAttributeDump(self.attr); },
           kDumpDocstring)
@@ -2793,15 +2818,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
       .def_property_readonly(
           "context", [](PyType &self) { return self.getContext().getObject(); },
           "Context that owns the Type")
-      .def("__eq__",
-           [](PyType &self, py::object &other) {
-             try {
-               PyType otherType = other.cast<PyType>();
-               return self == otherType;
-             } catch (std::exception &e) {
-               return false;
-             }
-           })
+      .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
+      .def("__eq__", [](PyType &self, py::object &other) { return false; })
       .def(
           "dump", [](PyType &self) { mlirTypeDump(self.type); }, kDumpDocstring)
       .def(

diff  --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
index e7fdbb9e7a5c..5236187c5b1f 100644
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -366,6 +366,24 @@ class PyModule : public BaseContextObject {
   pybind11::handle handle;
 };
 
+/// Base class for PyOperation and PyOpView which exposes the primary, user
+/// visible methods for manipulating it.
+class PyOperationBase {
+public:
+  virtual ~PyOperationBase() = default;
+  /// Implements the bound 'print' method and helps with others.
+  void print(pybind11::object fileObject, bool binary,
+             llvm::Optional<int64_t> largeElementsLimit, bool enableDebugInfo,
+             bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope);
+  pybind11::object getAsm(bool binary,
+                          llvm::Optional<int64_t> largeElementsLimit,
+                          bool enableDebugInfo, bool prettyDebugInfo,
+                          bool printGenericOpForm, bool useLocalScope);
+
+  /// Each subclass must provide access to the backing PyOperation.
+  virtual PyOperation &getOperation() = 0;
+};
+
 /// Wrapper around PyOperation.
 /// Operations exist in either an attached (dependent) or detached (top-level)
 /// state. In the detached state (as on creation), an operation is owned by
@@ -374,9 +392,11 @@ class PyModule : public BaseContextObject {
 /// is bounded by its top-level parent reference.
 class PyOperation;
 using PyOperationRef = PyObjectRef<PyOperation>;
-class PyOperation : public BaseContextObject {
+class PyOperation : public PyOperationBase, public BaseContextObject {
 public:
   ~PyOperation();
+  PyOperation &getOperation() override { return *this; }
+
   /// Returns a PyOperation for the given MlirOperation, optionally associating
   /// it with a parentKeepAlive.
   static PyOperationRef
@@ -407,15 +427,6 @@ class PyOperation : public BaseContextObject {
   }
   void checkValid();
 
-  /// Implements the bound 'print' method and helps with others.
-  void print(pybind11::object fileObject, bool binary,
-             llvm::Optional<int64_t> largeElementsLimit, bool enableDebugInfo,
-             bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope);
-  pybind11::object getAsm(bool binary,
-                          llvm::Optional<int64_t> largeElementsLimit,
-                          bool enableDebugInfo, bool prettyDebugInfo,
-                          bool printGenericOpForm, bool useLocalScope);
-
   /// Gets the owning block or raises an exception if the operation has no
   /// owning block.
   PyBlock getBlock();
@@ -432,6 +443,9 @@ class PyOperation : public BaseContextObject {
          llvm::Optional<std::vector<PyBlock *>> successors, int regions,
          DefaultingPyLocation location, pybind11::object ip);
 
+  /// Creates an OpView (custom subclass if registered) for this operation.
+  pybind11::object createOpView();
+
 private:
   PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
   static PyOperationRef createInstance(PyMlirContextRef contextRef,
@@ -456,17 +470,18 @@ class PyOperation : public BaseContextObject {
 /// custom ODS-style operation classes. Since this class is subclass on the
 /// python side, it must present an __init__ method that operates in pure
 /// python types.
-class PyOpView {
+class PyOpView : public PyOperationBase {
 public:
-  PyOpView(pybind11::object operation);
+  PyOpView(pybind11::object operationObject);
+  PyOperation &getOperation() override { return operation; }
 
   static pybind11::object createRawSubclass(pybind11::object userClass);
 
   pybind11::object getOperationObject() { return operationObject; }
 
 private:
+  PyOperation &operation;           // Fast C++ access; must be declared first.
   pybind11::object operationObject; // Holds the reference.
-  PyOperation *operation;           // For efficient, cast-free access from C++
 };
 
 /// Wrapper around an MlirRegion.
@@ -519,7 +534,7 @@ class PyInsertionPoint {
   /// block, but still inside the block.
   PyInsertionPoint(PyBlock &block);
   /// Creates an insertion point positioned before a reference operation.
-  PyInsertionPoint(PyOperation &beforeOperation);
+  PyInsertionPoint(PyOperationBase &beforeOperationBase);
 
   /// Shortcut to create an insertion point at the beginning of the block.
   static PyInsertionPoint atBlockBegin(PyBlock &block);
@@ -527,7 +542,7 @@ class PyInsertionPoint {
   static PyInsertionPoint atBlockTerminator(PyBlock &block);
 
   /// Inserts an operation.
-  void insert(PyOperation &operation);
+  void insert(PyOperationBase &operationBase);
 
   /// Enter and exit the context manager.
   pybind11::object contextEnter();
@@ -540,10 +555,10 @@ class PyInsertionPoint {
   // Trampoline constructor that avoids null initializing members while
   // looking up parents.
   PyInsertionPoint(PyBlock block, llvm::Optional<PyOperationRef> refOperation)
-      : block(std::move(block)), refOperation(std::move(refOperation)) {}
+      : refOperation(std::move(refOperation)), block(std::move(block)) {}
 
-  PyBlock block;
   llvm::Optional<PyOperationRef> refOperation;
+  PyBlock block; // Must follow refOperation (may be initialized from it).
 };
 
 /// Wrapper around the generic MlirAttribute.

diff  --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 1340468a8714..b2c1bafa5d69 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -30,17 +30,19 @@ PyGlobals::PyGlobals() {
 
 PyGlobals::~PyGlobals() { instance = nullptr; }
 
-void PyGlobals::loadDialectModule(const std::string &dialectNamespace) {
-  if (loadedDialectModules.contains(dialectNamespace))
+void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
+  py::gil_scoped_acquire acquire;
+  if (loadedDialectModulesCache.contains(dialectNamespace))
     return;
   // Since re-entrancy is possible, make a copy of the search prefixes.
   std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
   py::object loaded;
   for (std::string moduleName : localSearchPrefixes) {
     moduleName.push_back('.');
-    moduleName.append(dialectNamespace);
+    moduleName.append(dialectNamespace.data(), dialectNamespace.size());
 
     try {
+      // The GIL must stay held: importing runs arbitrary Python code.
       loaded = py::module::import(moduleName.c_str());
     } catch (py::error_already_set &e) {
       if (e.matches(PyExc_ModuleNotFoundError)) {
@@ -54,11 +56,12 @@ void PyGlobals::loadDialectModule(const std::string &dialectNamespace) {
 
   // Note: Iterator cannot be shared from prior to loading, since re-entrancy
   // may have occurred, which may do anything.
-  loadedDialectModules.insert(dialectNamespace);
+  loadedDialectModulesCache.insert(dialectNamespace);
 }
 
 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
                                     py::object pyClass) {
+  py::gil_scoped_acquire acquire;
   py::object &found = dialectClassMap[dialectNamespace];
   if (found) {
     throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") +
@@ -69,7 +72,9 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
 }
 
 void PyGlobals::registerOperationImpl(const std::string &operationName,
-                                      py::object pyClass, py::object rawClass) {
+                                      py::object pyClass,
+                                      py::object rawOpViewClass) {
+  py::gil_scoped_acquire acquire;
   py::object &found = operationClassMap[operationName];
   if (found) {
     throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") +
@@ -77,11 +82,14 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
                                              "' is already registered.");
   }
   found = std::move(pyClass);
-  rawOperationClassMap[operationName] = std::move(rawClass);
+  rawOpViewClassMap[operationName] = std::move(rawOpViewClass);
+  // Drop any cached (possibly negative) lookup so the new class is visible.
+  rawOpViewClassMapCache.erase(operationName);
 }
 
 llvm::Optional<py::object>
 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
+  py::gil_scoped_acquire acquire;
   loadDialectModule(dialectNamespace);
   // Fast match against the class map first (common case).
   const auto foundIt = dialectClassMap.find(dialectNamespace);
@@ -97,6 +105,49 @@ PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
   return llvm::None;
 }
 
+llvm::Optional<pybind11::object>
+PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) {
+  {
+    py::gil_scoped_acquire acquire;
+    auto foundIt = rawOpViewClassMapCache.find(operationName);
+    if (foundIt != rawOpViewClassMapCache.end()) {
+      if (foundIt->second.is_none())
+        return llvm::None;
+      assert(foundIt->second && "py::object is defined");
+      return foundIt->second;
+    }
+  }
+
+  // Not found. Load the dialect namespace.
+  auto split = operationName.split('.');
+  llvm::StringRef dialectNamespace = split.first;
+  loadDialectModule(dialectNamespace);
+
+  // Attempt to find from the canonical map and cache.
+  {
+    py::gil_scoped_acquire acquire;
+    auto foundIt = rawOpViewClassMap.find(operationName);
+    if (foundIt != rawOpViewClassMap.end()) {
+      if (foundIt->second.is_none())
+        return llvm::None;
+      assert(foundIt->second && "py::object is defined");
+      // Positive cache.
+      rawOpViewClassMapCache[operationName] = foundIt->second;
+      return foundIt->second;
+    } else {
+      // Negative cache (never write into the canonical map).
+      rawOpViewClassMapCache[operationName] = py::none();
+      return llvm::None;
+    }
+  }
+}
+
+void PyGlobals::clearImportCache() {
+  py::gil_scoped_acquire acquire;
+  loadedDialectModulesCache.clear();
+  rawOpViewClassMapCache.clear();
+}
+
 // -----------------------------------------------------------------------------
 // Module initialization.
 // -----------------------------------------------------------------------------
@@ -111,6 +162,7 @@ PYBIND11_MODULE(_mlir, m) {
       .def("append_dialect_search_prefix",
            [](PyGlobals &self, std::string moduleName) {
              self.getDialectSearchPrefixes().push_back(std::move(moduleName));
+             self.clearImportCache();
            })
       .def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
            "Testing hook for directly registering a dialect")

diff  --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py
index 9e0ba3071073..d5c5b3f121de 100644
--- a/mlir/test/Bindings/Python/ir_operation.py
+++ b/mlir/test/Bindings/Python/ir_operation.py
@@ -293,3 +293,34 @@ def testOperationPrint():
       pretty_debug_info=True, print_generic_op_form=True, use_local_scope=True)
 
 run(testOperationPrint)
+
+
+def testKnownOpView():
+  with Context(), Location.unknown():
+    Context.current.allow_unregistered_dialects = True
+    module = Module.parse(r"""
+      %1 = "custom.f32"() : () -> f32
+      %2 = "custom.f32"() : () -> f32
+      %3 = addf %1, %2 : f32
+    """)
+    print(module)
+
+    # addf should map to a known OpView class in the std dialect.
+    # We know the OpView for it defines an 'lhs' attribute.
+    addf = module.body.operations[2]
+    # CHECK: <mlir.dialects.std._AddFOp object
+    print(repr(addf))
+    # CHECK: "custom.f32"()
+    print(addf.lhs)
+
+    # An unregistered custom op should resolve to the generic OpView.
+    custom = module.body.operations[0]
+    # CHECK: <_mlir.ir.OpView object
+    print(repr(custom))
+
+    # Look it up again: the cached negative lookup must still yield OpView.
+    custom = module.body.operations[0]
+    # CHECK: <_mlir.ir.OpView object
+    print(repr(custom))
+
+run(testKnownOpView)


        


More information about the Mlir-commits mailing list