[Mlir-commits] [mlir] a7f8b7c - [mlir][python] Remove "Raw" OpView classes
Rahul Kayaith
llvmlistbot at llvm.org
Wed Mar 1 15:17:25 PST 2023
Author: Rahul Kayaith
Date: 2023-03-01T18:17:14-05:00
New Revision: a7f8b7cd8e49bb8680c34c1cc290a121ae37b4ac
URL: https://github.com/llvm/llvm-project/commit/a7f8b7cd8e49bb8680c34c1cc290a121ae37b4ac
DIFF: https://github.com/llvm/llvm-project/commit/a7f8b7cd8e49bb8680c34c1cc290a121ae37b4ac.diff
LOG: [mlir][python] Remove "Raw" OpView classes
The raw `OpView` classes are used to bypass the constructors of `OpView`
subclasses, but having a separate class can create some confusing
behaviour, e.g.:
```
op = MyOp(...)
# fails, lhs is 'MyOp', rhs is '_MyOp'
assert type(op) == type(op.operation.opview)
```
Instead, we can use `__new__` to achieve the same thing without a
separate class:
```
my_op = MyOp.__new__(MyOp)
OpView.__init__(my_op, op)
```
Reviewed By: stellaraccident
Differential Revision: https://reviews.llvm.org/D143830
Added:
Modified:
mlir/lib/Bindings/Python/Globals.h
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/Bindings/Python/IRModule.cpp
mlir/lib/Bindings/Python/IRModule.h
mlir/lib/Bindings/Python/MainModule.cpp
mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi
mlir/test/python/ir/operation.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index 8caa5a094a780..45d03689642ee 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -74,8 +74,7 @@ class PyGlobals {
/// Raises an exception if the mapping already exists.
/// This is intended to be called by implementation code.
void registerOperationImpl(const std::string &operationName,
- pybind11::object pyClass,
- pybind11::object rawOpViewClass);
+ pybind11::object pyClass);
/// Returns the custom Attribute builder for Attribute kind.
std::optional<pybind11::function>
@@ -86,10 +85,11 @@ class PyGlobals {
std::optional<pybind11::object>
lookupDialectClass(const std::string &dialectNamespace);
- /// Looks up a registered raw OpView class by operation name. Note that this
- /// may trigger a load of the dialect, which can arbitrarily re-enter.
+ /// Looks up a registered operation class (deriving from OpView) by operation
+ /// name. Note that this may trigger a load of the dialect, which can
+ /// arbitrarily re-enter.
std::optional<pybind11::object>
- lookupRawOpViewClass(llvm::StringRef operationName);
+ lookupOperationClass(llvm::StringRef operationName);
private:
static PyGlobals *instance;
@@ -99,21 +99,16 @@ class PyGlobals {
llvm::StringMap<pybind11::object> dialectClassMap;
/// Map of full operation name to external operation class object.
llvm::StringMap<pybind11::object> operationClassMap;
- /// Map of operation name to custom subclass that directly initializes
- /// the OpView base class (bypassing the user class constructor).
- llvm::StringMap<pybind11::object> rawOpViewClassMap;
/// Map of attribute ODS name to custom builder.
llvm::StringMap<pybind11::object> attributeBuilderMap;
/// Set of dialect namespaces that we have attempted to import implementation
/// modules for.
llvm::StringSet<> loadedDialectModulesCache;
- /// Cache of operation name to custom OpView subclass that directly
- /// initializes the OpView base class (or an undefined object for negative
- /// lookup). This is maintained on loopup as a shadow of rawOpViewClassMap
- /// in order for repeat lookups of the OpView classes to only incur the cost
- /// of one hashtable lookup.
- llvm::StringMap<pybind11::object> rawOpViewClassMapCache;
+ /// Cache of operation name to external operation class object. This is
+ /// maintained on lookup as a shadow of operationClassMap in order for repeat
+ /// lookups of the classes to only incur the cost of one hashtable lookup.
+ llvm::StringMap<pybind11::object> operationClassMapCache;
};
} // namespace python
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 12d37da5b098d..e03b6470c4dbd 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1339,10 +1339,10 @@ py::object PyOperation::createOpView() {
checkValid();
MlirIdentifier ident = mlirOperationGetName(get());
MlirStringRef identStr = mlirIdentifierStr(ident);
- auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
+ auto operationCls = PyGlobals::get().lookupOperationClass(
StringRef(identStr.data, identStr.length));
- if (opViewClass)
- return (*opViewClass)(getRef().getObject());
+ if (operationCls)
+ return PyOpView::constructDerived(*operationCls, *getRef().get());
return py::cast(PyOpView(getRef().getObject()));
}
@@ -1618,47 +1618,23 @@ PyOpView::buildGeneric(const py::object &cls, py::list resultTypeList,
/*regions=*/*regions, location, maybeIp);
}
+pybind11::object PyOpView::constructDerived(const pybind11::object &cls,
+ const PyOperation &operation) {
+ // TODO: pybind11 2.6 supports a more direct form.
+ // Upgrade many years from now.
+ // auto opViewType = py::type::of<PyOpView>();
+ py::handle opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
+ py::object instance = cls.attr("__new__")(cls);
+ opViewType.attr("__init__")(instance, operation);
+ return instance;
+}
+
PyOpView::PyOpView(const py::object &operationObject)
// Casting through the PyOperationBase base-class and then back to the
// Operation lets us accept any PyOperationBase subclass.
: operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
operationObject(operation.getRef().getObject()) {}
-py::object PyOpView::createRawSubclass(const py::object &userClass) {
- // This is... a little gross. The typical pattern is to have a pure python
- // class that extends OpView like:
- // class AddFOp(_cext.ir.OpView):
- // def __init__(self, loc, lhs, rhs):
- // operation = loc.context.create_operation(
- // "addf", lhs, rhs, results=[lhs.type])
- // super().__init__(operation)
- //
- // I.e. The goal of the user facing type is to provide a nice constructor
- // that has complete freedom for the op under construction. This is at odds
- // with our other desire to sometimes create this object by just passing an
- // operation (to initialize the base class). We could do *arg and **kwargs
- // munging to try to make it work, but instead, we synthesize a new class
- // on the fly which extends this user class (AddFOp in this example) and
- // *give it* the base class's __init__ method, thus bypassing the
- // intermediate subclass's __init__ method entirely. While slightly,
- // underhanded, this is safe/legal because the type hierarchy has not changed
- // (we just added a new leaf) and we aren't mucking around with __new__.
- // Typically, this new class will be stored on the original as "_Raw" and will
- // be used for casts and other things that need a variant of the class that
- // is initialized purely from an operation.
- py::object parentMetaclass =
- py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
- py::dict attributes;
- // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
- // now.
- // auto opViewType = py::type::of<PyOpView>();
- auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
- attributes["__init__"] = opViewType.attr("__init__");
- py::str origName = userClass.attr("__name__");
- py::str newName = py::str("_") + origName;
- return parentMetaclass(newName, py::make_tuple(userClass), attributes);
-}
-
//------------------------------------------------------------------------------
// PyInsertionPoint.
//------------------------------------------------------------------------------
@@ -2863,7 +2839,7 @@ void mlir::python::populateIRCore(py::module &m) {
throw py::value_error(
"Expected a '" + clsOpName + "' op, got: '" +
std::string(parsedOpName.data, parsedOpName.length) + "'");
- return cls.attr("_Raw")(parsed.getObject());
+ return PyOpView::constructDerived(cls, *parsed.get());
},
py::arg("cls"), py::arg("source"), py::kw_only(),
py::arg("source_name") = "", py::arg("context") = py::none(),
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index e3b8ef1893940..7221442e40b99 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -84,8 +84,7 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
}
void PyGlobals::registerOperationImpl(const std::string &operationName,
- py::object pyClass,
- py::object rawOpViewClass) {
+ py::object pyClass) {
py::object &found = operationClassMap[operationName];
if (found) {
throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") +
@@ -93,7 +92,6 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
"' is already registered.");
}
found = std::move(pyClass);
- rawOpViewClassMap[operationName] = std::move(rawOpViewClass);
}
std::optional<py::function>
@@ -130,10 +128,10 @@ PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
}
std::optional<pybind11::object>
-PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) {
+PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
{
- auto foundIt = rawOpViewClassMapCache.find(operationName);
- if (foundIt != rawOpViewClassMapCache.end()) {
+ auto foundIt = operationClassMapCache.find(operationName);
+ if (foundIt != operationClassMapCache.end()) {
if (foundIt->second.is_none())
return std::nullopt;
assert(foundIt->second && "py::object is defined");
@@ -148,22 +146,22 @@ PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) {
// Attempt to find from the canonical map and cache.
{
- auto foundIt = rawOpViewClassMap.find(operationName);
- if (foundIt != rawOpViewClassMap.end()) {
+ auto foundIt = operationClassMap.find(operationName);
+ if (foundIt != operationClassMap.end()) {
if (foundIt->second.is_none())
return std::nullopt;
assert(foundIt->second && "py::object is defined");
// Positive cache.
- rawOpViewClassMapCache[operationName] = foundIt->second;
+ operationClassMapCache[operationName] = foundIt->second;
return foundIt->second;
}
// Negative cache.
- rawOpViewClassMap[operationName] = py::none();
+ operationClassMap[operationName] = py::none();
return std::nullopt;
}
}
void PyGlobals::clearImportCache() {
loadedDialectModulesCache.clear();
- rawOpViewClassMapCache.clear();
+ operationClassMapCache.clear();
}
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index fa4bc1c3db1bf..4aced3639127a 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -654,8 +654,6 @@ class PyOpView : public PyOperationBase {
PyOpView(const pybind11::object &operationObject);
PyOperation &getOperation() override { return operation; }
- static pybind11::object createRawSubclass(const pybind11::object &userClass);
-
pybind11::object getOperationObject() { return operationObject; }
static pybind11::object
@@ -666,6 +664,16 @@ class PyOpView : public PyOperationBase {
std::optional<int> regions, DefaultingPyLocation location,
const pybind11::object &maybeIp);
+ /// Construct an instance of a class deriving from OpView, bypassing its
+ /// `__init__` method. The derived class will typically define a constructor
+ /// that provides a convenient builder, but we need to side-step this when
+ /// constructing an `OpView` for an already-built operation.
+ ///
+ /// The caller is responsible for verifying that `operation` is a valid
+ /// operation to construct `cls` with.
+ static pybind11::object constructDerived(const pybind11::object &cls,
+ const PyOperation &operation);
+
private:
PyOperation &operation; // For efficient, cast-free access from C++
pybind11::object operationObject; // Holds the reference.
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 1d6d8fa01d3bf..b32b4186fcb9f 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -41,7 +41,6 @@ PYBIND11_MODULE(_mlir, m) {
"Testing hook for directly registering a dialect")
.def("_register_operation_impl", &PyGlobals::registerOperationImpl,
py::arg("operation_name"), py::arg("operation_class"),
- py::arg("raw_opview_class"),
"Testing hook for directly registering an operation");
// Aside from making the globals accessible to python, having python manage
@@ -68,18 +67,11 @@ PYBIND11_MODULE(_mlir, m) {
[dialectClass](py::object opClass) -> py::object {
std::string operationName =
opClass.attr("OPERATION_NAME").cast<std::string>();
- auto rawSubclass = PyOpView::createRawSubclass(opClass);
- PyGlobals::get().registerOperationImpl(operationName, opClass,
- rawSubclass);
+ PyGlobals::get().registerOperationImpl(operationName, opClass);
// Dict-stuff the new opClass by name onto the dialect class.
py::object opClassName = opClass.attr("__name__");
dialectClass.attr(opClassName) = opClass;
-
- // Now create a special "Raw" subclass that passes through
- // construction to the OpView parent (bypasses the intermediate
- // child's __init__).
- opClass.attr("_Raw") = rawSubclass;
return opClass;
});
},
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi
index c8734cfdef59a..93b98c4aa53fb 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi
@@ -5,7 +5,7 @@ globals: "_Globals"
class _Globals:
dialect_search_modules: List[str]
def _register_dialect_impl(self, dialect_namespace: str, dialect_class: type) -> None: ...
- def _register_operation_impl(self, operation_name: str, operation_class: type, raw_opview_class: type) -> None: ...
+ def _register_operation_impl(self, operation_name: str, operation_class: type) -> None: ...
def append_dialect_search_prefix(self, module_name: str) -> None: ...
def register_dialect(dialect_class: type) -> object: ...
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index bca27a680bdea..be7467d12ff13 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -620,7 +620,7 @@ def testKnownOpView():
# addf should map to a known OpView class in the arithmetic dialect.
# We know the OpView for it defines an 'lhs' attribute.
addf = module.body.operations[2]
- # CHECK: <mlir.dialects._arith_ops_gen._AddFOp object
+ # CHECK: <mlir.dialects._arith_ops_gen.AddFOp object
print(repr(addf))
# CHECK: "custom.f32"()
print(addf.lhs)
More information about the Mlir-commits mailing list