[Mlir-commits] [mlir] a7f8b7c - [mlir][python] Remove "Raw" OpView classes

Rahul Kayaith llvmlistbot at llvm.org
Wed Mar 1 15:17:25 PST 2023


Author: Rahul Kayaith
Date: 2023-03-01T18:17:14-05:00
New Revision: a7f8b7cd8e49bb8680c34c1cc290a121ae37b4ac

URL: https://github.com/llvm/llvm-project/commit/a7f8b7cd8e49bb8680c34c1cc290a121ae37b4ac
DIFF: https://github.com/llvm/llvm-project/commit/a7f8b7cd8e49bb8680c34c1cc290a121ae37b4ac.diff

LOG: [mlir][python] Remove "Raw" OpView classes

The raw `OpView` classes are used to bypass the constructors of `OpView`
subclasses, but having a separate class can create some confusing
behaviour: the Python type of an op depends on how it was obtained, e.g.:
```
op = MyOp(...)
# fails, lhs is 'MyOp', rhs is '_MyOp'
assert type(op) == type(op.operation.opview)
```

Instead, we can use `__new__` to achieve the same thing without a
separate class:
```
my_op = MyOp.__new__(MyOp)
OpView.__init__(my_op, op)
```

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D143830

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/Globals.h
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/IRModule.cpp
    mlir/lib/Bindings/Python/IRModule.h
    mlir/lib/Bindings/Python/MainModule.cpp
    mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi
    mlir/test/python/ir/operation.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index 8caa5a094a780..45d03689642ee 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -74,8 +74,7 @@ class PyGlobals {
   /// Raises an exception if the mapping already exists.
   /// This is intended to be called by implementation code.
   void registerOperationImpl(const std::string &operationName,
-                             pybind11::object pyClass,
-                             pybind11::object rawOpViewClass);
+                             pybind11::object pyClass);
 
   /// Returns the custom Attribute builder for Attribute kind.
   std::optional<pybind11::function>
@@ -86,10 +85,11 @@ class PyGlobals {
   std::optional<pybind11::object>
   lookupDialectClass(const std::string &dialectNamespace);
 
-  /// Looks up a registered raw OpView class by operation name. Note that this
-  /// may trigger a load of the dialect, which can arbitrarily re-enter.
+  /// Looks up a registered operation class (deriving from OpView) by operation
+  /// name. Note that this may trigger a load of the dialect, which can
+  /// arbitrarily re-enter.
   std::optional<pybind11::object>
-  lookupRawOpViewClass(llvm::StringRef operationName);
+  lookupOperationClass(llvm::StringRef operationName);
 
 private:
   static PyGlobals *instance;
@@ -99,21 +99,16 @@ class PyGlobals {
   llvm::StringMap<pybind11::object> dialectClassMap;
   /// Map of full operation name to external operation class object.
   llvm::StringMap<pybind11::object> operationClassMap;
-  /// Map of operation name to custom subclass that directly initializes
-  /// the OpView base class (bypassing the user class constructor).
-  llvm::StringMap<pybind11::object> rawOpViewClassMap;
   /// Map of attribute ODS name to custom builder.
   llvm::StringMap<pybind11::object> attributeBuilderMap;
 
   /// Set of dialect namespaces that we have attempted to import implementation
   /// modules for.
   llvm::StringSet<> loadedDialectModulesCache;
-  /// Cache of operation name to custom OpView subclass that directly
-  /// initializes the OpView base class (or an undefined object for negative
-  /// lookup). This is maintained on loopup as a shadow of rawOpViewClassMap
-  /// in order for repeat lookups of the OpView classes to only incur the cost
-  /// of one hashtable lookup.
-  llvm::StringMap<pybind11::object> rawOpViewClassMapCache;
+  /// Cache of operation name to external operation class object. This is
+  /// maintained on lookup as a shadow of operationClassMap in order for repeat
+  /// lookups of the classes to only incur the cost of one hashtable lookup.
+  llvm::StringMap<pybind11::object> operationClassMapCache;
 };
 
 } // namespace python

diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 12d37da5b098d..e03b6470c4dbd 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1339,10 +1339,10 @@ py::object PyOperation::createOpView() {
   checkValid();
   MlirIdentifier ident = mlirOperationGetName(get());
   MlirStringRef identStr = mlirIdentifierStr(ident);
-  auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
+  auto operationCls = PyGlobals::get().lookupOperationClass(
       StringRef(identStr.data, identStr.length));
-  if (opViewClass)
-    return (*opViewClass)(getRef().getObject());
+  if (operationCls)
+    return PyOpView::constructDerived(*operationCls, *getRef().get());
   return py::cast(PyOpView(getRef().getObject()));
 }
 
@@ -1618,47 +1618,23 @@ PyOpView::buildGeneric(const py::object &cls, py::list resultTypeList,
                              /*regions=*/*regions, location, maybeIp);
 }
 
+pybind11::object PyOpView::constructDerived(const pybind11::object &cls,
+                                            const PyOperation &operation) {
+  // TODO: pybind11 2.6 supports a more direct form.
+  // Upgrade many years from now.
+  //   auto opViewType = py::type::of<PyOpView>();
+  py::handle opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
+  py::object instance = cls.attr("__new__")(cls);
+  opViewType.attr("__init__")(instance, operation);
+  return instance;
+}
+
 PyOpView::PyOpView(const py::object &operationObject)
     // Casting through the PyOperationBase base-class and then back to the
     // Operation lets us accept any PyOperationBase subclass.
     : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
       operationObject(operation.getRef().getObject()) {}
 
-py::object PyOpView::createRawSubclass(const py::object &userClass) {
-  // This is... a little gross. The typical pattern is to have a pure python
-  // class that extends OpView like:
-  //   class AddFOp(_cext.ir.OpView):
-  //     def __init__(self, loc, lhs, rhs):
-  //       operation = loc.context.create_operation(
-  //           "addf", lhs, rhs, results=[lhs.type])
-  //       super().__init__(operation)
-  //
-  // I.e. The goal of the user facing type is to provide a nice constructor
-  // that has complete freedom for the op under construction. This is at odds
-  // with our other desire to sometimes create this object by just passing an
-  // operation (to initialize the base class). We could do *arg and **kwargs
-  // munging to try to make it work, but instead, we synthesize a new class
-  // on the fly which extends this user class (AddFOp in this example) and
-  // *give it* the base class's __init__ method, thus bypassing the
-  // intermediate subclass's __init__ method entirely. While slightly,
-  // underhanded, this is safe/legal because the type hierarchy has not changed
-  // (we just added a new leaf) and we aren't mucking around with __new__.
-  // Typically, this new class will be stored on the original as "_Raw" and will
-  // be used for casts and other things that need a variant of the class that
-  // is initialized purely from an operation.
-  py::object parentMetaclass =
-      py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
-  py::dict attributes;
-  // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
-  // now.
-  //   auto opViewType = py::type::of<PyOpView>();
-  auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
-  attributes["__init__"] = opViewType.attr("__init__");
-  py::str origName = userClass.attr("__name__");
-  py::str newName = py::str("_") + origName;
-  return parentMetaclass(newName, py::make_tuple(userClass), attributes);
-}
-
 //------------------------------------------------------------------------------
 // PyInsertionPoint.
 //------------------------------------------------------------------------------
@@ -2863,7 +2839,7 @@ void mlir::python::populateIRCore(py::module &m) {
           throw py::value_error(
               "Expected a '" + clsOpName + "' op, got: '" +
               std::string(parsedOpName.data, parsedOpName.length) + "'");
-        return cls.attr("_Raw")(parsed.getObject());
+        return PyOpView::constructDerived(cls, *parsed.get());
       },
       py::arg("cls"), py::arg("source"), py::kw_only(),
       py::arg("source_name") = "", py::arg("context") = py::none(),

diff  --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index e3b8ef1893940..7221442e40b99 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -84,8 +84,7 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
 }
 
 void PyGlobals::registerOperationImpl(const std::string &operationName,
-                                      py::object pyClass,
-                                      py::object rawOpViewClass) {
+                                      py::object pyClass) {
   py::object &found = operationClassMap[operationName];
   if (found) {
     throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") +
@@ -93,7 +92,6 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
                                              "' is already registered.");
   }
   found = std::move(pyClass);
-  rawOpViewClassMap[operationName] = std::move(rawOpViewClass);
 }
 
 std::optional<py::function>
@@ -130,10 +128,10 @@ PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
 }
 
 std::optional<pybind11::object>
-PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) {
+PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
   {
-    auto foundIt = rawOpViewClassMapCache.find(operationName);
-    if (foundIt != rawOpViewClassMapCache.end()) {
+    auto foundIt = operationClassMapCache.find(operationName);
+    if (foundIt != operationClassMapCache.end()) {
       if (foundIt->second.is_none())
         return std::nullopt;
       assert(foundIt->second && "py::object is defined");
@@ -148,22 +146,22 @@ PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) {
 
   // Attempt to find from the canonical map and cache.
   {
-    auto foundIt = rawOpViewClassMap.find(operationName);
-    if (foundIt != rawOpViewClassMap.end()) {
+    auto foundIt = operationClassMap.find(operationName);
+    if (foundIt != operationClassMap.end()) {
       if (foundIt->second.is_none())
         return std::nullopt;
       assert(foundIt->second && "py::object is defined");
       // Positive cache.
-      rawOpViewClassMapCache[operationName] = foundIt->second;
+      operationClassMapCache[operationName] = foundIt->second;
       return foundIt->second;
     }
     // Negative cache.
-    rawOpViewClassMap[operationName] = py::none();
+    operationClassMap[operationName] = py::none();
     return std::nullopt;
   }
 }
 
 void PyGlobals::clearImportCache() {
   loadedDialectModulesCache.clear();
-  rawOpViewClassMapCache.clear();
+  operationClassMapCache.clear();
 }

diff  --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index fa4bc1c3db1bf..4aced3639127a 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -654,8 +654,6 @@ class PyOpView : public PyOperationBase {
   PyOpView(const pybind11::object &operationObject);
   PyOperation &getOperation() override { return operation; }
 
-  static pybind11::object createRawSubclass(const pybind11::object &userClass);
-
   pybind11::object getOperationObject() { return operationObject; }
 
   static pybind11::object
@@ -666,6 +664,16 @@ class PyOpView : public PyOperationBase {
                std::optional<int> regions, DefaultingPyLocation location,
                const pybind11::object &maybeIp);
 
+  /// Construct an instance of a class deriving from OpView, bypassing its
+  /// `__init__` method. The derived class will typically define a constructor
+  /// that provides a convenient builder, but we need to side-step this when
+  /// constructing an `OpView` for an already-built operation.
+  ///
+  /// The caller is responsible for verifying that `operation` is a valid
+  /// operation to construct `cls` with.
+  static pybind11::object constructDerived(const pybind11::object &cls,
+                                           const PyOperation &operation);
+
 private:
   PyOperation &operation;           // For efficient, cast-free access from C++
   pybind11::object operationObject; // Holds the reference.

diff  --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 1d6d8fa01d3bf..b32b4186fcb9f 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -41,7 +41,6 @@ PYBIND11_MODULE(_mlir, m) {
            "Testing hook for directly registering a dialect")
       .def("_register_operation_impl", &PyGlobals::registerOperationImpl,
            py::arg("operation_name"), py::arg("operation_class"),
-           py::arg("raw_opview_class"),
            "Testing hook for directly registering an operation");
 
   // Aside from making the globals accessible to python, having python manage
@@ -68,18 +67,11 @@ PYBIND11_MODULE(_mlir, m) {
             [dialectClass](py::object opClass) -> py::object {
               std::string operationName =
                   opClass.attr("OPERATION_NAME").cast<std::string>();
-              auto rawSubclass = PyOpView::createRawSubclass(opClass);
-              PyGlobals::get().registerOperationImpl(operationName, opClass,
-                                                     rawSubclass);
+              PyGlobals::get().registerOperationImpl(operationName, opClass);
 
               // Dict-stuff the new opClass by name onto the dialect class.
               py::object opClassName = opClass.attr("__name__");
               dialectClass.attr(opClassName) = opClass;
-
-              // Now create a special "Raw" subclass that passes through
-              // construction to the OpView parent (bypasses the intermediate
-              // child's __init__).
-              opClass.attr("_Raw") = rawSubclass;
               return opClass;
             });
       },

diff  --git a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi
index c8734cfdef59a..93b98c4aa53fb 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi
@@ -5,7 +5,7 @@ globals: "_Globals"
 class _Globals:
     dialect_search_modules: List[str]
     def _register_dialect_impl(self, dialect_namespace: str, dialect_class: type) -> None: ...
-    def _register_operation_impl(self, operation_name: str, operation_class: type, raw_opview_class: type) -> None: ...
+    def _register_operation_impl(self, operation_name: str, operation_class: type) -> None: ...
     def append_dialect_search_prefix(self, module_name: str) -> None: ...
 
 def register_dialect(dialect_class: type) -> object: ...

diff  --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index bca27a680bdea..be7467d12ff13 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -620,7 +620,7 @@ def testKnownOpView():
     # addf should map to a known OpView class in the arithmetic dialect.
     # We know the OpView for it defines an 'lhs' attribute.
     addf = module.body.operations[2]
-    # CHECK: <mlir.dialects._arith_ops_gen._AddFOp object
+    # CHECK: <mlir.dialects._arith_ops_gen.AddFOp object
     print(repr(addf))
     # CHECK: "custom.f32"()
     print(addf.lhs)


        


More information about the Mlir-commits mailing list