[Mlir-commits] [mlir] 013b932 - [mlir][Python] Custom python op view wrappers for building and traversing.

Stella Laurenzo llvmlistbot at llvm.org
Tue Oct 27 12:29:14 PDT 2020


Author: Stella Laurenzo
Date: 2020-10-27T12:23:34-07:00
New Revision: 013b9322dea9564ac85c4082fb0f07ff093eef63

URL: https://github.com/llvm/llvm-project/commit/013b9322dea9564ac85c4082fb0f07ff093eef63
DIFF: https://github.com/llvm/llvm-project/commit/013b9322dea9564ac85c4082fb0f07ff093eef63.diff

LOG: [mlir][Python] Custom python op view wrappers for building and traversing.

* Still rough edges that need more sugar but the bones are there. Notes left in the test case for things that can be improved.
* Does not actually yield custom OpViews yet for traversing. Will rework that in a followup.

Differential Revision: https://reviews.llvm.org/D89932

Added: 
    mlir/lib/Bindings/Python/Globals.h
    mlir/lib/Bindings/Python/mlir/dialects/__init__.py
    mlir/lib/Bindings/Python/mlir/dialects/std.py
    mlir/lib/Bindings/Python/mlir/ir.py
    mlir/test/Bindings/Python/dialects.py

Modified: 
    mlir/lib/Bindings/Python/CMakeLists.txt
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/lib/Bindings/Python/IRModules.h
    mlir/lib/Bindings/Python/MainModule.cpp
    mlir/lib/Bindings/Python/mlir/__init__.py

Removed: 
    


################################################################################
diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt
index f7b04fff7f63..d4913bb43947 100644
--- a/mlir/lib/Bindings/Python/CMakeLists.txt
+++ b/mlir/lib/Bindings/Python/CMakeLists.txt
@@ -4,6 +4,9 @@
 
 set(PY_SRC_FILES
   mlir/__init__.py
+  mlir/ir.py
+  mlir/dialects/__init__.py
+  mlir/dialects/std.py
 )
 
 add_custom_target(MLIRBindingsPythonSources ALL

diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
new file mode 100644
index 000000000000..33ab4cd6722d
--- /dev/null
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -0,0 +1,94 @@
+//===- Globals.h - MLIR Python extension globals --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H
+#define MLIR_BINDINGS_PYTHON_GLOBALS_H
+
+#include <string>
+#include <vector>
+
+#include "PybindUtils.h"
+
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/StringSet.h"
+
+namespace mlir {
+namespace python {
+
+/// Globals that are always accessible once the extension has been initialized.
+/// The constructor registers the object as the single live instance (and
+/// asserts that none exists yet); the destructor clears that registration.
+class PyGlobals {
+public:
+  PyGlobals();
+  ~PyGlobals();
+
+  /// Most code should get the globals via this static accessor. Asserts that
+  /// an instance is currently registered.
+  static PyGlobals &get() {
+    assert(instance && "PyGlobals is null");
+    return *instance;
+  }
+
+  /// Get and set the list of parent modules to search for dialect
+  /// implementation classes. The getter hands out a mutable reference to the
+  /// stored list; the setter replaces its contents.
+  std::vector<std::string> &getDialectSearchPrefixes() {
+    return dialectSearchPrefixes;
+  }
+  void setDialectSearchPrefixes(std::vector<std::string> newValues) {
+    dialectSearchPrefixes.swap(newValues);
+  }
+
+  /// Loads a python module corresponding to the given dialect namespace.
+  /// No-ops if the module has already been loaded or is not found. Raises
+  /// an error on any evaluation issues.
+  /// Note that this returns void because it is expected that the module
+  /// contains calls to decorators and helpers that register the salient
+  /// entities.
+  void loadDialectModule(const std::string &dialectNamespace);
+
+  /// Decorator for registering a custom Dialect class. The class object must
+  /// have a DIALECT_NAMESPACE attribute.
+  /// NOTE(review): no definition of this member is visible alongside the
+  /// other PyGlobals methods in this patch; the module-level
+  /// `register_dialect` function is implemented as a lambda instead. Confirm
+  /// whether this declaration is still needed.
+  pybind11::object registerDialectDecorator(pybind11::object pyClass);
+
+  /// Adds a concrete implementation dialect class.
+  /// Raises an exception if the mapping already exists.
+  /// This is intended to be called by implementation code.
+  void registerDialectImpl(const std::string &dialectNamespace,
+                           pybind11::object pyClass);
+
+  /// Adds a concrete implementation operation class.
+  /// Raises an exception if the mapping already exists.
+  /// This is intended to be called by implementation code.
+  void registerOperationImpl(const std::string &operationName,
+                             pybind11::object pyClass,
+                             pybind11::object rawClass);
+
+  /// Looks up a registered dialect class by namespace. Note that this may
+  /// trigger loading of the defining module and can arbitrarily re-enter.
+  /// Returns llvm::None when no class is registered for the namespace.
+  llvm::Optional<pybind11::object>
+  lookupDialectClass(const std::string &dialectNamespace);
+
+private:
+  /// The single live instance, or nullptr when none is constructed.
+  static PyGlobals *instance;
+  /// Module name prefixes to search under for dialect implementation modules.
+  std::vector<std::string> dialectSearchPrefixes;
+  /// Set of dialect namespaces for which module loading has already been
+  /// attempted, whether a module was imported or none was found.
+  llvm::StringSet<> loadedDialectModules;
+  /// Map of dialect namespace to external dialect class object.
+  llvm::StringMap<pybind11::object> dialectClassMap;
+  /// Map of full operation name to external operation class object.
+  llvm::StringMap<pybind11::object> operationClassMap;
+  /// Map of operation name to custom subclass that directly initializes
+  /// the OpView base class (bypassing the user class constructor).
+  llvm::StringMap<pybind11::object> rawOperationClassMap;
+};
+
+} // namespace python
+} // namespace mlir
+
+#endif // MLIR_BINDINGS_PYTHON_GLOBALS_H

diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 014b312971b7..2fba7fa5e283 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -7,6 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "IRModules.h"
+
+#include "Globals.h"
 #include "PybindUtils.h"
 
 #include "mlir-c/Bindings/Python/Interop.h"
@@ -209,19 +211,27 @@ struct PySinglePartStringAccumulator {
 } // namespace
 
 //------------------------------------------------------------------------------
-// Type-checking utilities.
+// Utilities.
 //------------------------------------------------------------------------------
 
-namespace {
-
 /// Checks whether the given type is an integer or float type.
-int mlirTypeIsAIntegerOrFloat(MlirType type) {
+static int mlirTypeIsAIntegerOrFloat(MlirType type) {
   return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
          mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
 }
 
-} // namespace
+/// Wraps `dialectDescriptor` in the user-facing dialect object for
+/// `dialectNamespace`: an instance of the class returned by
+/// PyGlobals::lookupDialectClass when there is one, otherwise a plain
+/// PyDialect.
+static py::object
+createCustomDialectWrapper(const std::string &dialectNamespace,
+                           py::object dialectDescriptor) {
+  auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
+  if (!dialectClass) {
+    // No custom class was found: use the base class.
+    return py::cast(PyDialect(std::move(dialectDescriptor)));
+  }
 
+  // Create the custom implementation by calling the registered class with the
+  // descriptor as its only argument.
+  return (*dialectClass)(std::move(dialectDescriptor));
+}
 //------------------------------------------------------------------------------
 // Collections.
 //------------------------------------------------------------------------------
@@ -567,9 +577,11 @@ size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
 
 py::object PyMlirContext::createOperation(
     std::string name, PyLocation location,
+    llvm::Optional<std::vector<PyValue *>> operands,
     llvm::Optional<std::vector<PyType *>> results,
     llvm::Optional<py::dict> attributes,
     llvm::Optional<std::vector<PyBlock *>> successors, int regions) {
+  llvm::SmallVector<MlirValue, 4> mlirOperands;
   llvm::SmallVector<MlirType, 4> mlirResults;
   llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
   llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
@@ -578,6 +590,16 @@ py::object PyMlirContext::createOperation(
   if (regions < 0)
     throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");
 
+  // Unpack/validate operands.
+  if (operands) {
+    mlirOperands.reserve(operands->size());
+    for (PyValue *operand : *operands) {
+      if (!operand)
+        throw SetPyError(PyExc_ValueError, "operand value cannot be None");
+      mlirOperands.push_back(operand->get());
+    }
+  }
+
   // Unpack/validate results.
   if (results) {
     mlirResults.reserve(results->size());
@@ -614,6 +636,9 @@ py::object PyMlirContext::createOperation(
   // Apply unpacked/validated to the operation state. Beyond this
   // point, exceptions cannot be thrown or else the state will leak.
   MlirOperationState state = mlirOperationStateGet(name.c_str(), location.loc);
+  if (!mlirOperands.empty())
+    mlirOperationStateAddOperands(&state, mlirOperands.size(),
+                                  mlirOperands.data());
   if (!mlirResults.empty())
     mlirOperationStateAddResults(&state, mlirResults.size(),
                                  mlirResults.data());
@@ -646,6 +671,24 @@ py::object PyMlirContext::createOperation(
   return PyOperation::createDetached(getRef(), operation).releaseObject();
 }
 
+//------------------------------------------------------------------------------
+// PyDialect, PyDialectDescriptor, PyDialects
+//------------------------------------------------------------------------------
+
+/// Gets or loads the dialect named `key` from this object's context via
+/// mlirContextGetOrLoadDialect. The key "std" is translated to the empty
+/// namespace before the lookup. If the resulting dialect is null, an error is
+/// raised through SetPyError: AttributeError when `attrError` is true,
+/// IndexError otherwise. The message reports the key as originally given.
+MlirDialect PyDialects::getDialectForKey(const std::string &key,
+                                         bool attrError) {
+  // If the "std" dialect was asked for, substitute the empty namespace :(
+  static const std::string emptyKey;
+  const std::string *canonKey = key == "std" ? &emptyKey : &key;
+  MlirDialect dialect = mlirContextGetOrLoadDialect(
+      getContext()->get(), {canonKey->data(), canonKey->size()});
+  if (mlirDialectIsNull(dialect)) {
+    throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
+                     llvm::Twine("Dialect '") + key + "' not found");
+  }
+  return dialect;
+}
+
 //------------------------------------------------------------------------------
 // PyModule
 //------------------------------------------------------------------------------
@@ -815,6 +858,45 @@ py::object PyOperation::getAsm(bool binary,
   return fileObject.attr("getvalue")();
 }
 
+/// Takes the Python object for an operation, stores it, and caches the
+/// PyOperation pointer obtained by casting that stored object.
+PyOpView::PyOpView(py::object operation)
+    : operationObject(std::move(operation)),
+      operation(py::cast<PyOperation *>(this->operationObject)) {}
+
+/// Synthesizes and returns a new Python class that derives from `userClass`,
+/// is named "_" followed by the user class's __name__, and whose __init__ is
+/// the one taken from the bound OpView type.
+py::object PyOpView::createRawSubclass(py::object userClass) {
+  // This is... a little gross. The typical pattern is to have a pure python
+  // class that extends OpView like:
+  //   class AddFOp(_cext.ir.OpView):
+  //     def __init__(self, loc, lhs, rhs):
+  //       operation = loc.context.create_operation(
+  //           "addf", lhs, rhs, results=[lhs.type])
+  //       super().__init__(operation)
+  //
+  // I.e. The goal of the user facing type is to provide a nice constructor
+  // that has complete freedom for the op under construction. This is at odds
+  // with our other desire to sometimes create this object by just passing an
+  // operation (to initialize the base class). We could do *arg and **kwargs
+  // munging to try to make it work, but instead, we synthesize a new class
+  // on the fly which extends this user class (AddFOp in this example) and
+  // *give it* the base class's __init__ method, thus bypassing the
+  // intermediate subclass's __init__ method entirely. While slightly
+  // underhanded, this is safe/legal because the type hierarchy has not changed
+  // (we just added a new leaf) and we aren't mucking around with __new__.
+  // Typically, this new class will be stored on the original as "_Raw" and will
+  // be used for casts and other things that need a variant of the class that
+  // is initialized purely from an operation.
+  // The builtin `type` object is called below to create the class.
+  py::object parentMetaclass =
+      py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
+  py::dict attributes;
+  // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
+  // now.
+  //   auto opViewType = py::type::of<PyOpView>();
+  auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
+  attributes["__init__"] = opViewType.attr("__init__");
+  py::str origName = userClass.attr("__name__");
+  py::str newName = py::str("_") + origName;
+  // Equivalent to the Python call: type(newName, (userClass,), attributes).
+  return parentMetaclass(newName, py::make_tuple(userClass), attributes);
+}
+
 //------------------------------------------------------------------------------
 // PyAttribute.
 //------------------------------------------------------------------------------
@@ -966,6 +1048,41 @@ class PyBlockArgumentList {
   MlirBlock block;
 };
 
+/// A list of operation operands. Elements are fetched by index from the
+/// underlying operation on each access. The operand list holds a reference to
+/// the operation whose operands these are, and extends the lifetime of this
+/// operation.
+class PyOpOperandList {
+public:
+  PyOpOperandList(PyOperationRef operation) : operation(operation) {}
+
+  /// Returns the number of operands, after checking the operation is valid.
+  intptr_t dunderLen() {
+    operation->checkValid();
+    return mlirOperationGetNumOperands(operation->get());
+  }
+
+  /// Returns the `index`-th operand. Negative and too-large indices raise
+  /// IndexError. NOTE(review): the operand is wrapped as a PyOpResult even
+  /// though an operand is not necessarily an operation result; confirm.
+  PyOpResult dunderGetItem(intptr_t index) {
+    if (index < 0 || index >= dunderLen()) {
+      throw SetPyError(PyExc_IndexError,
+                       "attempt to access out of bounds operand");
+    }
+    PyValue value(operation, mlirOperationGetOperand(operation->get(), index));
+    return PyOpResult(value);
+  }
+
+  /// Defines the Python class `OpOperandList` in the bindings.
+  static void bind(py::module &m) {
+    py::class_<PyOpOperandList>(m, "OpOperandList")
+        .def("__len__", &PyOpOperandList::dunderLen)
+        .def("__getitem__", &PyOpOperandList::dunderGetItem);
+  }
+
+private:
+  PyOperationRef operation;
+};
+
 /// A list of operation results. Internally, these are stored as consecutive
 /// elements, random access is cheap. The result list is associated with the
 /// operation whose results these are, and extends the lifetime of this
@@ -1914,7 +2031,9 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
 //------------------------------------------------------------------------------
 
 void mlir::python::populateIRSubmodule(py::module &m) {
+  //----------------------------------------------------------------------------
   // Mapping of MlirContext
+  //----------------------------------------------------------------------------
   py::class_<PyMlirContext>(m, "Context")
       .def(py::init<>(&PyMlirContext::createNewContextForInit))
       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
@@ -1928,6 +2047,25 @@ void mlir::python::populateIRSubmodule(py::module &m) {
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                              &PyMlirContext::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
+      .def_property_readonly(
+          "dialects",
+          [](PyMlirContext &self) { return PyDialects(self.getRef()); },
+          "Gets a container for accessing dialects by name")
+      .def_property_readonly(
+          "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
+          "Alias for 'dialect'")
+      .def(
+          "get_dialect_descriptor",
+          [=](PyMlirContext &self, std::string &name) {
+            MlirDialect dialect = mlirContextGetOrLoadDialect(
+                self.get(), {name.data(), name.size()});
+            if (mlirDialectIsNull(dialect)) {
+              throw SetPyError(PyExc_ValueError,
+                               llvm::Twine("Dialect '") + name + "' not found");
+            }
+            return PyDialectDescriptor(self.getRef(), dialect);
+          },
+          "Gets or loads a dialect by name, returning its descriptor object")
       .def_property(
           "allow_unregistered_dialects",
           [](PyMlirContext &self) -> bool {
@@ -1937,8 +2075,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
             mlirContextSetAllowUnregisteredDialects(self.get(), value);
           })
       .def("create_operation", &PyMlirContext::createOperation, py::arg("name"),
-           py::arg("location"), py::arg("results") = py::none(),
-           py::arg("attributes") = py::none(),
+           py::arg("location"), py::arg("operands") = py::none(),
+           py::arg("results") = py::none(), py::arg("attributes") = py::none(),
            py::arg("successors") = py::none(), py::arg("regions") = 0,
            kContextCreateOperationDocstring)
       .def(
@@ -2009,6 +2147,62 @@ void mlir::python::populateIRSubmodule(py::module &m) {
           kContextGetFileLocationDocstring, py::arg("filename"),
           py::arg("line"), py::arg("col"));
 
+  //----------------------------------------------------------------------------
+  // Mapping of PyDialectDescriptor
+  //----------------------------------------------------------------------------
+  py::class_<PyDialectDescriptor>(m, "DialectDescriptor")
+      .def_property_readonly("namespace",
+                             [](PyDialectDescriptor &self) {
+                               MlirStringRef ns =
+                                   mlirDialectGetNamespace(self.get());
+                               return py::str(ns.data, ns.length);
+                             })
+      .def("__repr__", [](PyDialectDescriptor &self) {
+        MlirStringRef ns = mlirDialectGetNamespace(self.get());
+        std::string repr("<DialectDescriptor ");
+        repr.append(ns.data, ns.length);
+        repr.append(">");
+        return repr;
+      });
+
+  //----------------------------------------------------------------------------
+  // Mapping of PyDialects
+  //----------------------------------------------------------------------------
+  py::class_<PyDialects>(m, "Dialects")
+      .def("__getitem__",
+           [=](PyDialects &self, std::string keyName) {
+             MlirDialect dialect =
+                 self.getDialectForKey(keyName, /*attrError=*/false);
+             py::object descriptor =
+                 py::cast(PyDialectDescriptor{self.getContext(), dialect});
+             return createCustomDialectWrapper(keyName, std::move(descriptor));
+           })
+      .def("__getattr__", [=](PyDialects &self, std::string attrName) {
+        MlirDialect dialect =
+            self.getDialectForKey(attrName, /*attrError=*/true);
+        py::object descriptor =
+            py::cast(PyDialectDescriptor{self.getContext(), dialect});
+        return createCustomDialectWrapper(attrName, std::move(descriptor));
+      });
+
+  //----------------------------------------------------------------------------
+  // Mapping of PyDialect
+  //----------------------------------------------------------------------------
+  py::class_<PyDialect>(m, "Dialect")
+      .def(py::init<py::object>(), "descriptor")
+      .def_property_readonly(
+          "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
+      .def("__repr__", [](py::object self) {
+        auto clazz = self.attr("__class__");
+        return py::str("<Dialect ") +
+               self.attr("descriptor").attr("namespace") + py::str(" (class ") +
+               clazz.attr("__module__") + py::str(".") +
+               clazz.attr("__name__") + py::str(")>");
+      });
+
+  //----------------------------------------------------------------------------
+  // Mapping of Location
+  //----------------------------------------------------------------------------
   py::class_<PyLocation>(m, "Location")
       .def_property_readonly(
           "context",
@@ -2021,7 +2215,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
         return printAccum.join();
       });
 
+  //----------------------------------------------------------------------------
   // Mapping of Module
+  //----------------------------------------------------------------------------
   py::class_<PyModule>(m, "Module")
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
@@ -2055,12 +2251,17 @@ void mlir::python::populateIRSubmodule(py::module &m) {
           },
           kOperationStrDunderDocstring);
 
+  //----------------------------------------------------------------------------
   // Mapping of Operation.
+  //----------------------------------------------------------------------------
   py::class_<PyOperation>(m, "Operation")
       .def_property_readonly(
           "context",
           [](PyOperation &self) { return self.getContext().getObject(); },
           "Context that owns the Operation")
+      .def_property_readonly(
+          "operands",
+          [](PyOperation &self) { return PyOpOperandList(self.getRef()); })
       .def_property_readonly(
           "regions",
           [](PyOperation &self) { return PyRegionList(self.getRef()); })
@@ -2098,7 +2299,15 @@ void mlir::python::populateIRSubmodule(py::module &m) {
            py::arg("print_generic_op_form") = false,
            py::arg("use_local_scope") = false, kOperationGetAsmDocstring);
 
+  py::class_<PyOpView>(m, "OpView")
+      .def(py::init<py::object>())
+      .def_property_readonly("operation", &PyOpView::getOperationObject)
+      .def("__str__",
+           [](PyOpView &self) { return py::str(self.getOperationObject()); });
+
+  //----------------------------------------------------------------------------
   // Mapping of PyRegion.
+  //----------------------------------------------------------------------------
   py::class_<PyRegion>(m, "Region")
       .def_property_readonly(
           "blocks",
@@ -2123,7 +2332,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
         }
       });
 
+  //----------------------------------------------------------------------------
   // Mapping of PyBlock.
+  //----------------------------------------------------------------------------
   py::class_<PyBlock>(m, "Block")
       .def_property_readonly(
           "arguments",
@@ -2167,7 +2378,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
           },
           "Returns the assembly form of the block.");
 
+  //----------------------------------------------------------------------------
   // Mapping of PyAttribute.
+  //----------------------------------------------------------------------------
   py::class_<PyAttribute>(m, "Attribute")
       .def_property_readonly(
           "context",
@@ -2219,6 +2432,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
         return printAccum.join();
       });
 
+  //----------------------------------------------------------------------------
+  // Mapping of PyNamedAttribute
+  //----------------------------------------------------------------------------
   py::class_<PyNamedAttribute>(m, "NamedAttribute")
       .def("__repr__",
            [](PyNamedAttribute &self) {
@@ -2257,7 +2473,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
   PyStringAttribute::bind(m);
   PyDenseElementsAttribute::bind(m);
 
+  //----------------------------------------------------------------------------
   // Mapping of PyType.
+  //----------------------------------------------------------------------------
   py::class_<PyType>(m, "Type")
       .def_property_readonly(
           "context", [](PyType &self) { return self.getContext().getObject(); },
@@ -2313,7 +2531,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
   PyTupleType::bind(m);
   PyFunctionType::bind(m);
 
+  //----------------------------------------------------------------------------
   // Mapping of Value.
+  //----------------------------------------------------------------------------
   py::class_<PyValue>(m, "Value")
       .def_property_readonly(
           "context",
@@ -2346,6 +2566,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
   PyBlockList::bind(m);
   PyOperationIterator::bind(m);
   PyOperationList::bind(m);
+  PyOpOperandList::bind(m);
   PyOpResultList::bind(m);
   PyRegionIterator::bind(m);
   PyRegionList::bind(m);

diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
index b438e8ac408d..89cca9c1c85f 100644
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -132,6 +132,7 @@ class PyMlirContext {
   /// Creates an operation. See corresponding python docstring.
   pybind11::object
   createOperation(std::string name, PyLocation location,
+                  llvm::Optional<std::vector<PyValue *>> operands,
                   llvm::Optional<std::vector<PyType *>> results,
                   llvm::Optional<pybind11::dict> attributes,
                   llvm::Optional<std::vector<PyBlock *>> successors,
@@ -187,6 +188,45 @@ class BaseContextObject {
   PyMlirContextRef contextRef;
 };
 
+/// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in
+/// order to differentiate it from the `Dialect` base class which is extended by
+/// plugins which extend dialect functionality through extension python code.
+/// This should be seen as the "low-level" object and `Dialect` as the
+/// high-level, user facing object.
+class PyDialectDescriptor : public BaseContextObject {
+public:
+  /// Associates `dialect` with the context referenced by `contextRef`.
+  PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect)
+      : BaseContextObject(std::move(contextRef)), dialect(dialect) {}
+
+  /// Returns the wrapped MlirDialect handle.
+  MlirDialect get() { return dialect; }
+
+private:
+  MlirDialect dialect;
+};
+
+/// User-level object for accessing dialects with dotted syntax such as:
+///   ctx.dialects.std
+class PyDialects : public BaseContextObject {
+public:
+  PyDialects(PyMlirContextRef contextRef)
+      : BaseContextObject(std::move(contextRef)) {}
+
+  /// Gets or loads the dialect named `key`. `attrError` selects which Python
+  /// exception type is raised when the dialect is not found (AttributeError
+  /// when true, IndexError when false).
+  MlirDialect getDialectForKey(const std::string &key, bool attrError);
+};
+
+/// User-level dialect object. For dialects that have a registered extension,
+/// this will be the base class of the extension dialect type. For dialects
+/// without an extension, objects of this type will be returned directly.
+class PyDialect {
+public:
+  /// Stores the given descriptor object.
+  PyDialect(pybind11::object descriptor) : descriptor(std::move(descriptor)) {}
+
+  /// Returns the descriptor object passed at construction.
+  pybind11::object getDescriptor() { return descriptor; }
+
+private:
+  pybind11::object descriptor;
+};
+
 /// Wrapper around an MlirLocation.
 class PyLocation : public BaseContextObject {
 public:
@@ -305,6 +345,24 @@ class PyOperation : public BaseContextObject {
   bool valid = true;
 };
 
+/// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for
+/// providing more instance-specific accessors and serve as the base class for
+/// custom ODS-style operation classes. Since this class is subclassed on the
+/// python side, it must present an __init__ method that operates in pure
+/// python types.
+class PyOpView {
+public:
+  /// Constructs a view over the given Python operation object.
+  PyOpView(pybind11::object operation);
+
+  /// Creates a subclass of `userClass` whose __init__ is the OpView one.
+  static pybind11::object createRawSubclass(pybind11::object userClass);
+
+  /// Returns the Python operation object this view was constructed with.
+  pybind11::object getOperationObject() { return operationObject; }
+
+private:
+  pybind11::object operationObject; // Holds the reference.
+  PyOperation *operation;           // For efficient, cast-free access from C++
+};
+
 /// Wrapper around an MlirRegion.
 /// Regions are managed completely by their containing operation. Unlike the
 /// C++ API, the python API does not support detached regions.

diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 7dd525b4b340..1340468a8714 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -8,17 +8,155 @@
 
 #include <tuple>
 
-#include <pybind11/pybind11.h>
+#include "PybindUtils.h"
 
+#include "Globals.h"
 #include "IRModules.h"
 
 namespace py = pybind11;
 using namespace mlir;
 using namespace mlir::python;
 
+// -----------------------------------------------------------------------------
+// PyGlobals
+// -----------------------------------------------------------------------------
+
+PyGlobals *PyGlobals::instance = nullptr;
+
+/// Registers this object as the global instance. Asserts that no other
+/// instance is currently registered.
+PyGlobals::PyGlobals() {
+  assert(!instance && "PyGlobals already constructed");
+  instance = this;
+}
+
+/// Clears the global instance pointer.
+PyGlobals::~PyGlobals() { instance = nullptr; }
+
+/// Attempts to import the Python module implementing `dialectNamespace`.
+/// Each search prefix is tried in order as "<prefix>.<dialectNamespace>" and
+/// the search stops at the first import that succeeds. A
+/// ModuleNotFoundError moves on to the next prefix; any other Python error is
+/// rethrown, in which case the namespace is not recorded as loaded. Returns
+/// immediately if the namespace is already in loadedDialectModules.
+void PyGlobals::loadDialectModule(const std::string &dialectNamespace) {
+  if (loadedDialectModules.contains(dialectNamespace))
+    return;
+  // Since re-entrancy is possible, make a copy of the search prefixes.
+  std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
+  py::object loaded;
+  // Each prefix is copied because it is extended in place into the full
+  // module name.
+  for (std::string moduleName : localSearchPrefixes) {
+    moduleName.push_back('.');
+    moduleName.append(dialectNamespace);
+
+    try {
+      loaded = py::module::import(moduleName.c_str());
+    } catch (py::error_already_set &e) {
+      if (e.matches(PyExc_ModuleNotFoundError)) {
+        continue;
+      } else {
+        throw;
+      }
+    }
+    break;
+  }
+
+  // Note: Iterator cannot be shared from prior to loading, since re-entrancy
+  // may have occurred, which may do anything.
+  // The namespace is recorded whether or not any module was imported.
+  loadedDialectModules.insert(dialectNamespace);
+}
+
+/// Registers `pyClass` as the implementation class for `dialectNamespace`.
+/// Raises RuntimeError (via SetPyError) if a class is already registered for
+/// that namespace.
+void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
+                                    py::object pyClass) {
+  py::object &found = dialectClassMap[dialectNamespace];
+  // lookupDialectClass() stores py::none() as a negative-cache entry. That is
+  // a non-null py::object, so it has to be excluded explicitly; otherwise a
+  // dialect that was looked up before being registered would be rejected as
+  // "already registered".
+  if (found && !found.is_none()) {
+    throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") +
+                                             dialectNamespace +
+                                             "' is already registered.");
+  }
+  found = std::move(pyClass);
+}
+
+/// Registers `pyClass` as the implementation class for `operationName` and
+/// `rawClass` as its raw counterpart. Raises RuntimeError (via SetPyError) if
+/// the operation class map already holds an entry for that name.
+void PyGlobals::registerOperationImpl(const std::string &operationName,
+                                      py::object pyClass, py::object rawClass) {
+  py::object &found = operationClassMap[operationName];
+  if (found) {
+    throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") +
+                                             operationName +
+                                             "' is already registered.");
+  }
+  found = std::move(pyClass);
+  // The raw-class entry is assigned unconditionally once the check passes.
+  rawOperationClassMap[operationName] = std::move(rawClass);
+}
+
+/// Returns the class registered for `dialectNamespace`, first attempting to
+/// load the dialect's Python module. Returns llvm::None when nothing is
+/// registered; that outcome is remembered by storing py::none() in the class
+/// map.
+llvm::Optional<py::object>
+PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
+  loadDialectModule(dialectNamespace);
+  // Fast match against the class map first (common case).
+  const auto foundIt = dialectClassMap.find(dialectNamespace);
+  if (foundIt != dialectClassMap.end()) {
+    // A py::none() entry is the negative-cache marker written below.
+    if (foundIt->second.is_none())
+      return llvm::None;
+    assert(foundIt->second && "py::object is defined");
+    return foundIt->second;
+  }
+
+  // Not found and loading did not yield a registration. Negative cache.
+  dialectClassMap[dialectNamespace] = py::none();
+  return llvm::None;
+}
+
+// -----------------------------------------------------------------------------
+// Module initialization.
+// -----------------------------------------------------------------------------
+
 PYBIND11_MODULE(_mlir, m) {
   m.doc() = "MLIR Python Native Extension";
 
+  py::class_<PyGlobals>(m, "_Globals")
+      .def_property("dialect_search_modules",
+                    &PyGlobals::getDialectSearchPrefixes,
+                    &PyGlobals::setDialectSearchPrefixes)
+      .def("append_dialect_search_prefix",
+           [](PyGlobals &self, std::string moduleName) {
+             self.getDialectSearchPrefixes().push_back(std::move(moduleName));
+           })
+      .def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
+           "Testing hook for directly registering a dialect")
+      .def("_register_operation_impl", &PyGlobals::registerOperationImpl,
+           "Testing hook for directly registering an operation");
+
+  // Aside from making the globals accessible to python, having python manage
+  // it is necessary to make sure it is destroyed (and releases its python
+  // resources) properly.
+  m.attr("globals") =
+      py::cast(new PyGlobals, py::return_value_policy::take_ownership);
+
+  // Registration decorators.
+  m.def(
+      "register_dialect",
+      [](py::object pyClass) {
+        std::string dialectNamespace =
+            pyClass.attr("DIALECT_NAMESPACE").cast<std::string>();
+        PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
+        return pyClass;
+      },
+      "Class decorator for registering a custom Dialect wrapper");
+  m.def(
+      "register_operation",
+      [](py::object dialectClass) -> py::cpp_function {
+        return py::cpp_function(
+            [dialectClass](py::object opClass) -> py::object {
+              std::string operationName =
+                  opClass.attr("OPERATION_NAME").cast<std::string>();
+              auto rawSubclass = PyOpView::createRawSubclass(opClass);
+              PyGlobals::get().registerOperationImpl(operationName, opClass,
+                                                     rawSubclass);
+
+              // Dict-stuff the new opClass by name onto the dialect class.
+              py::object opClassName = opClass.attr("__name__");
+              dialectClass.attr(opClassName) = opClass;
+
+              // Now create a special "Raw" subclass that passes through
+              // construction to the OpView parent (bypasses the intermediate
+              // child's __init__).
+              opClass.attr("_Raw") = rawSubclass;
+              return opClass;
+            });
+      },
+      "Class decorator for registering a custom Operation wrapper");
+
   // Define and populate IR submodule.
   auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
   populateIRSubmodule(irModule);

diff --git a/mlir/lib/Bindings/Python/mlir/__init__.py b/mlir/lib/Bindings/Python/mlir/__init__.py
index 717526771c25..8f3b52c30f35 100644
--- a/mlir/lib/Bindings/Python/mlir/__init__.py
+++ b/mlir/lib/Bindings/Python/mlir/__init__.py
@@ -8,4 +8,37 @@
 # and arbitrate any one-time initialization needed in various shared-library
 # scenarios.
 
-from _mlir import *
+__all__ = [
+  "ir",
+]
+
+# Expose the corresponding C-Extension module with a well-known name at this
+# top-level module. This allows relative imports like the following to
+# function:
+#   from .. import _cext
+# This reduces coupling, allowing embedding of the python sources into another
+# project that can just vary based on this top-level loader module.
+import _mlir as _cext
+
+def _reexport_cext(cext_module_name, target_module_name):
+  """Re-exports a named sub-module of the C-Extension into another module.
+
+  Typically:
+    from . import _reexport_cext
+    _reexport_cext("ir", __name__)
+    del _reexport_cext
+  """
+  import sys
+  target_module = sys.modules[target_module_name]
+  source_module = getattr(_cext, cext_module_name)
+  for attr_name in dir(source_module):
+    if not attr_name.startswith("__"):
+      setattr(target_module, attr_name, getattr(source_module, attr_name))
+
+
+# Import sub-modules. Since these may import from here, this must come after
+# any exported definitions.
+from . import ir
+
+# Add our 'dialects' parent module to the search path for implementations.
+_cext.globals.append_dialect_search_prefix("mlir.dialects")

diff  --git a/mlir/lib/Bindings/Python/mlir/dialects/__init__.py b/mlir/lib/Bindings/Python/mlir/dialects/__init__.py
new file mode 100644
index 000000000000..1b7e62c030fb
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/dialects/__init__.py
@@ -0,0 +1,6 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Re-export the parent _cext so that every level of the API can get it locally.
+from .. import _cext

diff  --git a/mlir/lib/Bindings/Python/mlir/dialects/std.py b/mlir/lib/Bindings/Python/mlir/dialects/std.py
new file mode 100644
index 000000000000..2afc642e0e3d
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/dialects/std.py
@@ -0,0 +1,33 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# TODO: This file should be auto-generated.
+
+from . import _cext
+
+@_cext.register_dialect
+class _Dialect(_cext.ir.Dialect):
+  # Special case: 'std' namespace aliases to the empty namespace.
+  DIALECT_NAMESPACE = "std"
+  pass
+
+@_cext.register_operation(_Dialect)
+class AddFOp(_cext.ir.OpView):
+  OPERATION_NAME = "std.addf"
+
+  def __init__(self, loc, lhs, rhs):
+    super().__init__(loc.context.create_operation(
+      "std.addf", loc, operands=[lhs, rhs], results=[lhs.type]))
+
+  @property
+  def lhs(self):
+    return self.operation.operands[0]
+
+  @property
+  def rhs(self):
+    return self.operation.operands[1]
+
+  @property
+  def result(self):
+    return self.operation.results[0]

diff  --git a/mlir/lib/Bindings/Python/mlir/ir.py b/mlir/lib/Bindings/Python/mlir/ir.py
new file mode 100644
index 000000000000..70d19737f5e6
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/ir.py
@@ -0,0 +1,8 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Simply a wrapper around the extension module of the same name.
+from . import _reexport_cext
+_reexport_cext("ir", __name__)
+del _reexport_cext

diff  --git a/mlir/test/Bindings/Python/dialects.py b/mlir/test/Bindings/Python/dialects.py
new file mode 100644
index 000000000000..bc88e8668f4d
--- /dev/null
+++ b/mlir/test/Bindings/Python/dialects.py
@@ -0,0 +1,107 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import gc
+import mlir
+
+def run(f):
+  print("\nTEST:", f.__name__)
+  f()
+  gc.collect()
+  assert mlir.ir.Context._get_live_count() == 0
+
+
+# CHECK-LABEL: TEST: testDialectDescriptor
+def testDialectDescriptor():
+  ctx = mlir.ir.Context()
+  d = ctx.get_dialect_descriptor("std")
+  # CHECK: <DialectDescriptor std>
+  print(d)
+  # CHECK: std
+  print(d.namespace)
+  try:
+    _ = ctx.get_dialect_descriptor("not_existing")
+  except ValueError:
+    pass
+  else:
+    assert False, "Expected exception"
+
+run(testDialectDescriptor)
+
+
+# CHECK-LABEL: TEST: testUserDialectClass
+def testUserDialectClass():
+  ctx = mlir.ir.Context()
+  # Access using attribute.
+  d = ctx.dialects.std
+  # Note that the standard dialect namespace prints as ''. Others will print
+  # as "<Dialect %namespace (..."
+  # CHECK: <Dialect (class mlir.dialects.std._Dialect)>
+  print(d)
+  try:
+    _ = ctx.dialects.not_existing
+  except AttributeError:
+    pass
+  else:
+    assert False, "Expected exception"
+
+  # Access using index.
+  d = ctx.dialects["std"]
+  # CHECK: <Dialect (class mlir.dialects.std._Dialect)>
+  print(d)
+  try:
+    _ = ctx.dialects["not_existing"]
+  except IndexError:
+    pass
+  else:
+    assert False, "Expected exception"
+
+  # Using the 'd' alias.
+  d = ctx.d["std"]
+  # CHECK: <Dialect (class mlir.dialects.std._Dialect)>
+  print(d)
+
+run(testUserDialectClass)
+
+
+# CHECK-LABEL: TEST: testCustomOpView
+# This test uses the standard dialect AddFOp as an example of a user op.
+# TODO: Op creation and access is still quite verbose: simplify this test as
+# additional capabilities come online.
+def testCustomOpView():
+  ctx = mlir.ir.Context()
+  ctx.allow_unregistered_dialects = True
+  f32 = mlir.ir.F32Type.get(ctx)
+  loc = ctx.get_unknown_location()
+  m = ctx.create_module(loc)
+  m_block = m.operation.regions[0].blocks[0]
+  # TODO: Remove integer insertion in favor of InsertionPoint and/or op-based.
+  ip = [0]
+  def createInput():
+    op = ctx.create_operation("pytest_dummy.intinput", loc, results=[f32])
+    m_block.operations.insert(ip[0], op)
+    ip[0] += 1
+    # TODO: Auto result cast from operation
+    return op.results[0]
+
+  # Create via dialects context collection.
+  input1 = createInput()
+  input2 = createInput()
+  op1 = ctx.dialects.std.AddFOp(loc, input1, input2)
+  # TODO: Auto operation cast from OpView
+  # TODO: Context manager insertion point
+  m_block.operations.insert(ip[0], op1.operation)
+  ip[0] += 1
+
+  # Create via an import
+  from mlir.dialects.std import AddFOp
+  op2 = AddFOp(loc, input1, op1.result)
+  m_block.operations.insert(ip[0], op2.operation)
+  ip[0] += 1
+
+  # CHECK: %[[INPUT0:.*]] = "pytest_dummy.intinput"
+  # CHECK: %[[INPUT1:.*]] = "pytest_dummy.intinput"
+  # CHECK: %[[R0:.*]] = addf %[[INPUT0]], %[[INPUT1]] : f32
+  # CHECK: %[[R1:.*]] = addf %[[INPUT0]], %[[R0]] : f32
+  m.operation.print()
+
+run(testCustomOpView)


        


More information about the Mlir-commits mailing list