[Mlir-commits] [mlir] f13893f - [mlir][Python] Upstream the PybindAdaptors.h helpers and use it to implement sparse_tensor.encoding.

Stella Laurenzo llvmlistbot at llvm.org
Mon May 10 10:16:00 PDT 2021


Author: Stella Laurenzo
Date: 2021-05-10T17:15:43Z
New Revision: f13893f66a228400bf9bdf14be425e3dc6da0034

URL: https://github.com/llvm/llvm-project/commit/f13893f66a228400bf9bdf14be425e3dc6da0034
DIFF: https://github.com/llvm/llvm-project/commit/f13893f66a228400bf9bdf14be425e3dc6da0034.diff

LOG: [mlir][Python] Upstream the PybindAdaptors.h helpers and use it to implement sparse_tensor.encoding.

* The PybindAdaptors.h file has been evolving across different sub-projects (npcomp, circt) and has been successfully used for out of tree python API interop/extensions and defining custom types.
* Since sparse_tensor.encoding is the first in-tree custom attribute we are supporting, it seemed like the right time to upstream this header and use it to define the attribute in a way that we can support for both in-tree and out-of-tree use (previously, I had not wanted to upstream dead code that was not used in-tree).
* Adapted the circt version of `mlir_type_subclass`, also providing an `mlir_attribute_subclass`. As we get a bit of mileage on this, I would like to transition the builtin types/attributes to this mechanism and delete the old in-tree only `PyConcreteType` and `PyConcreteAttribute` template helpers (which cannot work reliably out of tree as they depend on internals).
* Added support for defaulting the MlirContext if none is passed so that we can support the same idioms as in-tree versions.
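
The subclassing mechanism described in the bullets above can be illustrated in pure Python. This is only a rough sketch of the idea behind `pure_subclass` and `mlir_attribute_subclass` (call the superclass's metaclass to build a class at runtime, then attach a casting `__init__` and a static `isinstance`); it is not the C++ implementation. The `Attribute` class, `make_attribute_subclass`, and `is_encoding` are hypothetical stand-ins invented for the example and are not part of the MLIR Python API.

```python
class Attribute:
    """Hypothetical stand-in for mlir.ir.Attribute; wraps a raw payload."""

    def __init__(self, cast_from):
        # Accept another Attribute (reuse its payload) or a raw payload.
        self.raw = cast_from.raw if isinstance(cast_from, Attribute) else cast_from

    def __eq__(self, other):
        return isinstance(other, Attribute) and self.raw == other.raw

    def __hash__(self):
        return hash(self.raw)


def make_attribute_subclass(name, isa_function, super_class=Attribute):
    """Builds a subclass at runtime, similar in spirit to pure_subclass."""
    # Call the superclass's metaclass: metaclass(name, bases, attributes).
    metaclass = type(super_class)
    this_class = metaclass(name, (super_class,), {})

    def __init__(self, cast_from):
        # Casting constructor: reject values that fail the isa check.
        if not isa_function(Attribute(cast_from)):
            raise ValueError(
                f"Cannot cast attribute to {name} (from {cast_from!r})")
        super_class.__init__(self, cast_from)

    this_class.__init__ = __init__
    # Static type-checking helper, analogous to the 'isinstance' method.
    this_class.isinstance = staticmethod(isa_function)
    return this_class


def is_encoding(attr):
    # Hypothetical isa predicate: payloads starting with "#sparse" qualify.
    return isinstance(attr.raw, str) and attr.raw.startswith("#sparse")


EncodingAttr = make_attribute_subclass("EncodingAttr", is_encoding)
```

With these stand-ins, `EncodingAttr(Attribute("#sparse_tensor.encoding<{}>"))` yields an object that compares equal to the original attribute, while passing an attribute that fails the predicate raises `ValueError`, which is also the Python exception pybind11 produces for `std::invalid_argument`.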

There is quite a bit going on here and I can split it up if needed, but would prefer to keep the first use and the header together so sending out in one patch.
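
The context-defaulting idiom mentioned in the last bullet above can be sketched as follows. This is a simplified model under stated assumptions, not the actual caster: `Context` here is a hypothetical stand-in with a thread-local stack, `resolve_context` is an invented helper name, and `current` is written as a plain static method for brevity (the C++ caster reads `Context.current` as an attribute).

```python
import threading


class Context:
    """Hypothetical stand-in for mlir.ir.Context with a thread-bound stack."""

    _tls = threading.local()

    @staticmethod
    def _stack():
        # Lazily create one stack per thread.
        if not hasattr(Context._tls, "stack"):
            Context._tls.stack = []
        return Context._tls.stack

    def __enter__(self):
        Context._stack().append(self)
        return self

    def __exit__(self, *exc_info):
        Context._stack().pop()
        return False

    @staticmethod
    def current():
        stack = Context._stack()
        if not stack:
            raise ValueError("No current context")
        return stack[-1]


def resolve_context(context=None):
    """None means 'use the thread-bound context'; otherwise pass through."""
    return Context.current() if context is None else context
```

Inside a `with Context() as ctx:` block, `resolve_context()` returns `ctx`, which is why a factory such as `EncodingAttr.get` can declare `context` with a default of `None` and still work under the usual `with Context():` idiom.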

Differential Revision: https://reviews.llvm.org/D102144

Added: 
    mlir/include/mlir/Bindings/Python/PybindAdaptors.h
    mlir/lib/Bindings/Python/DialectSparseTensor.cpp
    mlir/lib/Bindings/Python/Dialects.h
    mlir/test/python/dialects/sparse_tensor/dialect.py

Modified: 
    mlir/lib/Bindings/Python/CMakeLists.txt
    mlir/lib/Bindings/Python/DialectLinalg.cpp
    mlir/lib/Bindings/Python/MainModule.cpp

Removed: 
    mlir/lib/Bindings/Python/DialectLinalg.h


################################################################################
diff  --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
new file mode 100644
index 0000000000000..db8769d3c35f3
--- /dev/null
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -0,0 +1,428 @@
+//===- PybindAdaptors.h - Adaptors for interop with MLIR APIs -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// This file contains adaptors for clients of the core MLIR Python APIs to
+// interop via MLIR CAPI types. The facilities here do not depend on
+// implementation details of the MLIR Python API and do not introduce C++-level
+// dependencies with it (requiring only Python and CAPI-level dependencies).
+//
+// It is encouraged to be used both in-tree and out-of-tree. For in-tree use
+// cases, it should be used for dialect implementations (versus relying on
+// Pybind-based internals of the core libraries).
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BINDINGS_PYTHON_PYBIND_ADAPTORS_H
+#define MLIR_BINDINGS_PYTHON_PYBIND_ADAPTORS_H
+
+#include <pybind11/pybind11.h>
+#include <pybind11/pytypes.h>
+#include <pybind11/stl.h>
+
+#include "mlir-c/Bindings/Python/Interop.h"
+#include "mlir-c/IR.h"
+
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/Twine.h"
+
+namespace py = pybind11;
+
+// TODO: Move this to Interop.h and make it externally configurable/use it
+// consistently to locate the "import mlir" top-level.
+#define MLIR_PYTHON_PACKAGE_PREFIX "mlir."
+
+// Raw CAPI type casters need to be declared before use, so always include them
+// first.
+namespace pybind11 {
+namespace detail {
+
+template <typename T>
+struct type_caster<llvm::Optional<T>> : optional_caster<llvm::Optional<T>> {};
+
+/// Helper to convert a presumed MLIR API object to a capsule, accepting either
+/// an explicit Capsule (which can happen when two C APIs are communicating
+/// directly via Python) or indirectly by querying the MLIR_PYTHON_CAPI_PTR_ATTR
+/// attribute (through which supported MLIR Python API objects export their
+/// contained API pointer as a capsule). This is intended to be used from
+/// type casters, which are invoked with a raw handle (unowned). The returned
+/// object's lifetime may not extend beyond the apiObject handle without
+/// explicitly having its refcount increased (i.e. on return).
+static py::object mlirApiObjectToCapsule(py::handle apiObject) {
+  if (PyCapsule_CheckExact(apiObject.ptr()))
+    return py::reinterpret_borrow<py::object>(apiObject);
+  return apiObject.attr(MLIR_PYTHON_CAPI_PTR_ATTR);
+}
+
+// Note: Currently all of the following support cast from py::object to the
+// Mlir* C-API type, but only a few light-weight, context-bound ones
+// implicitly cast the other way because the use case has not yet emerged and
+// ownership is unclear.
+
+/// Casts object <-> MlirAffineMap.
+template <>
+struct type_caster<MlirAffineMap> {
+  PYBIND11_TYPE_CASTER(MlirAffineMap, _("MlirAffineMap"));
+  bool load(handle src, bool) {
+    py::object capsule = mlirApiObjectToCapsule(src);
+    value = mlirPythonCapsuleToAffineMap(capsule.ptr());
+    if (mlirAffineMapIsNull(value)) {
+      return false;
+    }
+    return true;
+  }
+  static handle cast(MlirAffineMap v, return_value_policy, handle) {
+    py::object capsule =
+        py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(v));
+    return py::module::import(MLIR_PYTHON_PACKAGE_PREFIX "ir")
+        .attr("AffineMap")
+        .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+        .release();
+  }
+};
+
+/// Casts object <-> MlirAttribute.
+template <>
+struct type_caster<MlirAttribute> {
+  PYBIND11_TYPE_CASTER(MlirAttribute, _("MlirAttribute"));
+  bool load(handle src, bool) {
+    py::object capsule = mlirApiObjectToCapsule(src);
+    value = mlirPythonCapsuleToAttribute(capsule.ptr());
+    if (mlirAttributeIsNull(value)) {
+      return false;
+    }
+    return true;
+  }
+  static handle cast(MlirAttribute v, return_value_policy, handle) {
+    py::object capsule =
+        py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(v));
+    return py::module::import(MLIR_PYTHON_PACKAGE_PREFIX "ir")
+        .attr("Attribute")
+        .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+        .release();
+  }
+};
+
+/// Casts object -> MlirContext.
+template <>
+struct type_caster<MlirContext> {
+  PYBIND11_TYPE_CASTER(MlirContext, _("MlirContext"));
+  bool load(handle src, bool) {
+    if (src.is_none()) {
+      // Gets the current thread-bound context.
+      // TODO: This raises an error of "No current context" currently.
+      // Update the implementation to pretty-print the helpful error that the
+      // core implementations print in this case.
+      src = py::module::import(MLIR_PYTHON_PACKAGE_PREFIX "ir")
+                .attr("Context")
+                .attr("current");
+    }
+    py::object capsule = mlirApiObjectToCapsule(src);
+    value = mlirPythonCapsuleToContext(capsule.ptr());
+    if (mlirContextIsNull(value)) {
+      return false;
+    }
+    return true;
+  }
+};
+
+/// Casts object <-> MlirLocation.
+// TODO: Coerce None to default MlirLocation.
+template <>
+struct type_caster<MlirLocation> {
+  PYBIND11_TYPE_CASTER(MlirLocation, _("MlirLocation"));
+  bool load(handle src, bool) {
+    py::object capsule = mlirApiObjectToCapsule(src);
+    value = mlirPythonCapsuleToLocation(capsule.ptr());
+    if (mlirLocationIsNull(value)) {
+      return false;
+    }
+    return true;
+  }
+  static handle cast(MlirLocation v, return_value_policy, handle) {
+    py::object capsule =
+        py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(v));
+    return py::module::import(MLIR_PYTHON_PACKAGE_PREFIX "ir")
+        .attr("Location")
+        .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+        .release();
+  }
+};
+
+/// Casts object <-> MlirModule.
+template <>
+struct type_caster<MlirModule> {
+  PYBIND11_TYPE_CASTER(MlirModule, _("MlirModule"));
+  bool load(handle src, bool) {
+    py::object capsule = mlirApiObjectToCapsule(src);
+    value = mlirPythonCapsuleToModule(capsule.ptr());
+    if (mlirModuleIsNull(value)) {
+      return false;
+    }
+    return true;
+  }
+  static handle cast(MlirModule v, return_value_policy, handle) {
+    py::object capsule =
+        py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(v));
+    return py::module::import(MLIR_PYTHON_PACKAGE_PREFIX "ir")
+        .attr("Module")
+        .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+        .release();
+  };
+};
+
+/// Casts object <-> MlirOperation.
+template <>
+struct type_caster<MlirOperation> {
+  PYBIND11_TYPE_CASTER(MlirOperation, _("MlirOperation"));
+  bool load(handle src, bool) {
+    py::object capsule = mlirApiObjectToCapsule(src);
+    value = mlirPythonCapsuleToOperation(capsule.ptr());
+    if (mlirOperationIsNull(value)) {
+      return false;
+    }
+    return true;
+  }
+  static handle cast(MlirOperation v, return_value_policy, handle) {
+    if (v.ptr == nullptr)
+      return py::none();
+    py::object capsule =
+        py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(v));
+    return py::module::import(MLIR_PYTHON_PACKAGE_PREFIX "ir")
+        .attr("Operation")
+        .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+        .release();
+  };
+};
+
+/// Casts object -> MlirPassManager.
+template <>
+struct type_caster<MlirPassManager> {
+  PYBIND11_TYPE_CASTER(MlirPassManager, _("MlirPassManager"));
+  bool load(handle src, bool) {
+    py::object capsule = mlirApiObjectToCapsule(src);
+    value = mlirPythonCapsuleToPassManager(capsule.ptr());
+    if (mlirPassManagerIsNull(value)) {
+      return false;
+    }
+    return true;
+  }
+};
+
+/// Casts object <-> MlirType.
+template <>
+struct type_caster<MlirType> {
+  PYBIND11_TYPE_CASTER(MlirType, _("MlirType"));
+  bool load(handle src, bool) {
+    py::object capsule = mlirApiObjectToCapsule(src);
+    value = mlirPythonCapsuleToType(capsule.ptr());
+    if (mlirTypeIsNull(value)) {
+      return false;
+    }
+    return true;
+  }
+  static handle cast(MlirType t, return_value_policy, handle) {
+    py::object capsule =
+        py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(t));
+    return py::module::import(MLIR_PYTHON_PACKAGE_PREFIX "ir")
+        .attr("Type")
+        .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+        .release();
+  }
+};
+
+} // namespace detail
+} // namespace pybind11
+
+namespace mlir {
+namespace python {
+namespace adaptors {
+
+/// Provides a facility like py::class_ for defining a new class in a scope,
+/// but this allows extension of an arbitrary Python class, defining methods
+/// on it in a similar way. Classes defined in this way are very similar to
+/// if defined in Python in the usual way but use Pybind11 machinery to do
+/// it. These are not "real" Pybind11 classes but pure Python classes with no
+/// relation to a concrete C++ class.
+///
+/// Derived from a discussion upstream:
+///   https://github.com/pybind/pybind11/issues/1193
+///   (plus a fair amount of extracurricular poking)
+///   TODO: If this proves useful, see about including it in pybind11.
+class pure_subclass {
+public:
+  pure_subclass(py::handle scope, const char *derivedClassName,
+                py::object superClass) {
+    py::object pyType =
+        py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
+    py::object metaclass = pyType(superClass);
+    py::dict attributes;
+
+    thisClass =
+        metaclass(derivedClassName, py::make_tuple(superClass), attributes);
+    scope.attr(derivedClassName) = thisClass;
+  }
+
+  template <typename Func, typename... Extra>
+  pure_subclass &def(const char *name, Func &&f, const Extra &...extra) {
+    py::cpp_function cf(
+        std::forward<Func>(f), py::name(name), py::is_method(py::none()),
+        py::sibling(py::getattr(thisClass, name, py::none())), extra...);
+    thisClass.attr(cf.name()) = cf;
+    return *this;
+  }
+
+  template <typename Func, typename... Extra>
+  pure_subclass &def_property_readonly(const char *name, Func &&f,
+                                       const Extra &...extra) {
+    py::cpp_function cf(
+        std::forward<Func>(f), py::name(name), py::is_method(py::none()),
+        py::sibling(py::getattr(thisClass, name, py::none())), extra...);
+    auto builtinProperty =
+        py::reinterpret_borrow<py::object>((PyObject *)&PyProperty_Type);
+    thisClass.attr(name) = builtinProperty(cf);
+    return *this;
+  }
+
+  template <typename Func, typename... Extra>
+  pure_subclass &def_staticmethod(const char *name, Func &&f,
+                                  const Extra &...extra) {
+    static_assert(!std::is_member_function_pointer<Func>::value,
+                  "def_staticmethod(...) called with a non-static member "
+                  "function pointer");
+    py::cpp_function cf(
+        std::forward<Func>(f), py::name(name), py::scope(thisClass),
+        py::sibling(py::getattr(thisClass, name, py::none())), extra...);
+    thisClass.attr(cf.name()) = py::staticmethod(cf);
+    return *this;
+  }
+
+  template <typename Func, typename... Extra>
+  pure_subclass &def_classmethod(const char *name, Func &&f,
+                                 const Extra &...extra) {
+    static_assert(!std::is_member_function_pointer<Func>::value,
+                  "def_classmethod(...) called with a non-static member "
+                  "function pointer");
+    py::cpp_function cf(
+        std::forward<Func>(f), py::name(name), py::scope(thisClass),
+        py::sibling(py::getattr(thisClass, name, py::none())), extra...);
+    thisClass.attr(cf.name()) =
+        py::reinterpret_steal<py::object>(PyClassMethod_New(cf.ptr()));
+    return *this;
+  }
+
+protected:
+  py::object superClass;
+  py::object thisClass;
+};
+
+/// Creates a custom subclass of mlir.ir.Attribute, implementing a casting
+/// constructor and type checking methods.
+class mlir_attribute_subclass : public pure_subclass {
+public:
+  using IsAFunctionTy = bool (*)(MlirAttribute);
+
+  /// Subclasses by looking up the super-class dynamically.
+  mlir_attribute_subclass(py::handle scope, const char *attrClassName,
+                          IsAFunctionTy isaFunction)
+      : mlir_attribute_subclass(
+            scope, attrClassName, isaFunction,
+            py::module::import(MLIR_PYTHON_PACKAGE_PREFIX "ir")
+                .attr("Attribute")) {}
+
+  /// Subclasses with a provided mlir.ir.Attribute super-class. This must
+  /// be used if the subclass is being defined in the same extension module
+  /// as the mlir.ir class (otherwise, it will trigger a recursive
+  /// initialization).
+  mlir_attribute_subclass(py::handle scope, const char *typeClassName,
+                          IsAFunctionTy isaFunction, py::object superClass)
+      : pure_subclass(scope, typeClassName, superClass) {
+    // Casting constructor. Note that defining an __init__ method is special
+    // and not yet generalized on pure_subclass (it requires a somewhat
+    // different cpp_function and other requirements on chaining to super
+    // __init__ make it more awkward to do generally).
+    std::string captureTypeName(
+        typeClassName); // As string in case typeClassName is not static.
+    py::cpp_function initCf(
+        [superClass, isaFunction, captureTypeName](py::object self,
+                                                   py::object otherType) {
+          MlirAttribute rawAttribute = py::cast<MlirAttribute>(otherType);
+          if (!isaFunction(rawAttribute)) {
+            auto origRepr = py::repr(otherType).cast<std::string>();
+            throw std::invalid_argument(
+                (llvm::Twine("Cannot cast attribute to ") + captureTypeName +
+                 " (from " + origRepr + ")")
+                    .str());
+          }
+          superClass.attr("__init__")(self, otherType);
+        },
+        py::arg("cast_from_type"), py::is_method(py::none()),
+        "Casts the passed attribute to this specific sub-type.");
+    thisClass.attr("__init__") = initCf;
+
+    // 'isinstance' method.
+    def_staticmethod(
+        "isinstance",
+        [isaFunction](MlirAttribute other) { return isaFunction(other); },
+        py::arg("other_attribute"));
+  }
+};
+
+/// Creates a custom subclass of mlir.ir.Type, implementing a casting
+/// constructor and type checking methods.
+class mlir_type_subclass : public pure_subclass {
+public:
+  using IsAFunctionTy = bool (*)(MlirType);
+
+  /// Subclasses by looking up the super-class dynamically.
+  mlir_type_subclass(py::handle scope, const char *typeClassName,
+                     IsAFunctionTy isaFunction)
+      : mlir_type_subclass(
+            scope, typeClassName, isaFunction,
+            py::module::import(MLIR_PYTHON_PACKAGE_PREFIX "ir").attr("Type")) {}
+
+  /// Subclasses with a provided mlir.ir.Type super-class. This must
+  /// be used if the subclass is being defined in the same extension module
+  /// as the mlir.ir class (otherwise, it will trigger a recursive
+  /// initialization).
+  mlir_type_subclass(py::handle scope, const char *typeClassName,
+                     IsAFunctionTy isaFunction, py::object superClass)
+      : pure_subclass(scope, typeClassName, superClass) {
+    // Casting constructor. Note that defining an __init__ method is special
+    // and not yet generalized on pure_subclass (it requires a somewhat
+    // different cpp_function and other requirements on chaining to super
+    // __init__ make it more awkward to do generally).
+    std::string captureTypeName(
+        typeClassName); // As string in case typeClassName is not static.
+    py::cpp_function initCf(
+        [superClass, isaFunction, captureTypeName](py::object self,
+                                                   py::object otherType) {
+          MlirType rawType = py::cast<MlirType>(otherType);
+          if (!isaFunction(rawType)) {
+            auto origRepr = py::repr(otherType).cast<std::string>();
+            throw std::invalid_argument((llvm::Twine("Cannot cast type to ") +
+                                         captureTypeName + " (from " +
+                                         origRepr + ")")
+                                            .str());
+          }
+          superClass.attr("__init__")(self, otherType);
+        },
+        py::arg("cast_from_type"), py::is_method(py::none()),
+        "Casts the passed type to this specific sub-type.");
+    thisClass.attr("__init__") = initCf;
+
+    // 'isinstance' method.
+    def_staticmethod(
+        "isinstance",
+        [isaFunction](MlirType other) { return isaFunction(other); },
+        py::arg("other_type"));
+  }
+};
+
+} // namespace adaptors
+} // namespace python
+} // namespace mlir
+
+#endif // MLIR_BINDINGS_PYTHON_PYBIND_ADAPTORS_H

diff  --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt
index a2e972dc15df4..7dc1f64b4f57e 100644
--- a/mlir/lib/Bindings/Python/CMakeLists.txt
+++ b/mlir/lib/Bindings/Python/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_python_extension(MLIRCoreBindingsPythonExtension _mlir
     python
   SOURCES
     DialectLinalg.cpp
+    DialectSparseTensor.cpp
     MainModule.cpp
     IRAffine.cpp
     IRAttributes.cpp

diff  --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp
index 849a0039a3ccb..dfac96db74b12 100644
--- a/mlir/lib/Bindings/Python/DialectLinalg.cpp
+++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp
@@ -6,20 +6,19 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "Dialects.h"
 #include "IRModule.h"
 #include "mlir-c/Dialect/Linalg.h"
 #include "mlir-c/IR.h"
 
-#include <pybind11/pybind11.h>
+// TODO: Port this to operate only on the public PybindAdaptors.h
+#include "PybindUtils.h"
 
 namespace py = pybind11;
 using namespace mlir;
 using namespace mlir::python;
 
-namespace mlir {
-namespace python {
-
-void populateDialectLinalgSubmodule(py::module &m) {
+void mlir::python::populateDialectLinalgSubmodule(py::module m) {
   m.def(
       "fill_builtin_region",
       [](PyDialectDescriptor &dialect, PyOperation &op, py::list captures) {
@@ -34,6 +33,3 @@ void populateDialectLinalgSubmodule(py::module &m) {
       "Fill the region for `op`, which is assumed to be a builtin named Linalg "
       "op.");
 }
-
-} // namespace python
-} // namespace mlir

diff  --git a/mlir/lib/Bindings/Python/DialectLinalg.h b/mlir/lib/Bindings/Python/DialectLinalg.h
deleted file mode 100644
index 3735dbf6f6286..0000000000000
--- a/mlir/lib/Bindings/Python/DialectLinalg.h
+++ /dev/null
@@ -1,22 +0,0 @@
-//===- DialectLinalg.h - Linalg dialect submodule of pybind module --------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_BINDINGS_PYTHON_DIALECTLINALG_H
-#define MLIR_BINDINGS_PYTHON_DIALECTLINALG_H
-
-#include "PybindUtils.h"
-
-namespace mlir {
-namespace python {
-
-void populateDialectLinalgSubmodule(pybind11::module &m);
-
-} // namespace python
-} // namespace mlir
-
-#endif // MLIR_BINDINGS_PYTHON_DIALECTLINALG_H

diff  --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
new file mode 100644
index 0000000000000..faf240e1a6633
--- /dev/null
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -0,0 +1,74 @@
+//===- DialectSparseTensor.cpp - 'sparse_tensor' dialect submodule --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Dialects.h"
+#include "mlir-c/Dialect/SparseTensor.h"
+#include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/PybindAdaptors.h"
+
+namespace py = pybind11;
+using namespace llvm;
+using namespace mlir;
+using namespace mlir::python::adaptors;
+
+void mlir::python::populateDialectSparseTensorSubmodule(
+    py::module m, const py::module &irModule) {
+  auto attributeClass = irModule.attr("Attribute");
+
+  py::enum_<MlirSparseTensorDimLevelType>(m, "DimLevelType")
+      .value("dense", MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE)
+      .value("compressed", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED)
+      .value("singleton", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON);
+
+  mlir_attribute_subclass(m, "EncodingAttr",
+                          mlirAttributeIsASparseTensorEncodingAttr,
+                          attributeClass)
+      .def_classmethod(
+          "get",
+          [](py::object cls,
+             std::vector<MlirSparseTensorDimLevelType> dimLevelTypes,
+             llvm::Optional<MlirAffineMap> dimOrdering, int pointerBitWidth,
+             int indexBitWidth, MlirContext context) {
+            return cls(mlirSparseTensorEncodingAttrGet(
+                context, dimLevelTypes.size(), dimLevelTypes.data(),
+                dimOrdering ? *dimOrdering : MlirAffineMap{nullptr},
+                pointerBitWidth, indexBitWidth));
+          },
+          py::arg("cls"), py::arg("dim_level_types"), py::arg("dim_ordering"),
+          py::arg("pointer_bit_width"), py::arg("index_bit_width"),
+          py::arg("context") = py::none(),
+          "Gets a sparse_tensor.encoding from parameters.")
+      .def_property_readonly(
+          "dim_level_types",
+          [](MlirAttribute self) {
+            std::vector<MlirSparseTensorDimLevelType> ret;
+            for (int i = 0,
+                     e = mlirSparseTensorEncodingGetNumDimLevelTypes(self);
+                 i < e; ++i)
+              ret.push_back(
+                  mlirSparseTensorEncodingAttrGetDimLevelType(self, i));
+            return ret;
+          })
+      .def_property_readonly(
+          "dim_ordering",
+          [](MlirAttribute self) -> llvm::Optional<MlirAffineMap> {
+            MlirAffineMap ret =
+                mlirSparseTensorEncodingAttrGetDimOrdering(self);
+            if (mlirAffineMapIsNull(ret))
+              return {};
+            return ret;
+          })
+      .def_property_readonly(
+          "pointer_bit_width",
+          [](MlirAttribute self) {
+            return mlirSparseTensorEncodingAttrGetPointerBitWidth(self);
+          })
+      .def_property_readonly("index_bit_width", [](MlirAttribute self) {
+        return mlirSparseTensorEncodingAttrGetIndexBitWidth(self);
+      });
+}

diff  --git a/mlir/lib/Bindings/Python/Dialects.h b/mlir/lib/Bindings/Python/Dialects.h
new file mode 100644
index 0000000000000..301d539275d08
--- /dev/null
+++ b/mlir/lib/Bindings/Python/Dialects.h
@@ -0,0 +1,24 @@
+//===- Dialects.h - Declaration for dialect submodule factories -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BINDINGS_PYTHON_DIALECTS_H
+#define MLIR_BINDINGS_PYTHON_DIALECTS_H
+
+#include <pybind11/pybind11.h>
+
+namespace mlir {
+namespace python {
+
+void populateDialectLinalgSubmodule(pybind11::module m);
+void populateDialectSparseTensorSubmodule(pybind11::module m,
+                                          const pybind11::module &irModule);
+
+} // namespace python
+} // namespace mlir
+
+#endif // MLIR_BINDINGS_PYTHON_DIALECTS_H

diff  --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 60c282d1d9d33..6e861c2f2f761 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -10,7 +10,7 @@
 
 #include "PybindUtils.h"
 
-#include "DialectLinalg.h"
+#include "Dialects.h"
 #include "ExecutionEngine.h"
 #include "Globals.h"
 #include "IRModule.h"
@@ -98,8 +98,10 @@ PYBIND11_MODULE(_mlir, m) {
       m.def_submodule("execution_engine", "MLIR JIT Execution Engine");
   populateExecutionEngineSubmodule(executionEngineModule);
 
-  // Define and populate Linalg submodule.
+  // Define and populate dialect submodules.
   auto dialectsModule = m.def_submodule("dialects");
   auto linalgModule = dialectsModule.def_submodule("linalg");
   populateDialectLinalgSubmodule(linalgModule);
+  populateDialectSparseTensorSubmodule(
+      dialectsModule.def_submodule("sparse_tensor"), irModule);
 }

diff  --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
new file mode 100644
index 0000000000000..f10116de2033a
--- /dev/null
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -0,0 +1,76 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+# TODO: Import this into the user-package vs the cext.
+from _mlir.dialects import sparse_tensor as st
+
+def run(f):
+  print("\nTEST:", f.__name__)
+  f()
+  return f
+
+
+# CHECK-LABEL: TEST: testEncodingAttr1D
+@run
+def testEncodingAttr1D():
+  with Context() as ctx:
+    parsed = Attribute.parse(
+      '#sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], '
+      'pointerBitWidth = 16, indexBitWidth = 32 }>')
+    print(parsed)
+
+    casted = st.EncodingAttr(parsed)
+    # CHECK: equal: True
+    print(f"equal: {casted == parsed}")
+
+    # CHECK: dim_level_types: [<DimLevelType.compressed: 1>]
+    print(f"dim_level_types: {casted.dim_level_types}")
+    # CHECK: dim_ordering: None
+    # Note that for 1D, the ordering is None, which exercises several special
+    # cases.
+    print(f"dim_ordering: {casted.dim_ordering}")
+    # CHECK: pointer_bit_width: 16
+    print(f"pointer_bit_width: {casted.pointer_bit_width}")
+    # CHECK: index_bit_width: 32
+    print(f"index_bit_width: {casted.index_bit_width}")
+
+    created = st.EncodingAttr.get(casted.dim_level_types, None, 16, 32)
+    print(created)
+    # CHECK: created_equal: True
+    print(f"created_equal: {created == casted}")
+
+    # Verify that the factory creates an instance of the proper type.
+    # CHECK: is_proper_instance: True
+    print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}")
+    # CHECK: created_pointer_bit_width: 16
+    print(f"created_pointer_bit_width: {created.pointer_bit_width}")
+
+
+# CHECK-LABEL: TEST: testEncodingAttr2D
+@run
+def testEncodingAttr2D():
+  with Context() as ctx:
+    parsed = Attribute.parse(
+      '#sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], '
+      'dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, '
+      'pointerBitWidth = 16, indexBitWidth = 32 }>')
+    print(parsed)
+
+    casted = st.EncodingAttr(parsed)
+    # CHECK: equal: True
+    print(f"equal: {casted == parsed}")
+
+    # CHECK: dim_level_types: [<DimLevelType.dense: 0>, <DimLevelType.compressed: 1>]
+    print(f"dim_level_types: {casted.dim_level_types}")
+    # CHECK: dim_ordering: (d0, d1) -> (d0, d1)
+    print(f"dim_ordering: {casted.dim_ordering}")
+    # CHECK: pointer_bit_width: 16
+    print(f"pointer_bit_width: {casted.pointer_bit_width}")
+    # CHECK: index_bit_width: 32
+    print(f"index_bit_width: {casted.index_bit_width}")
+
+    created = st.EncodingAttr.get(casted.dim_level_types, casted.dim_ordering,
+        16, 32)
+    print(created)
+    # CHECK: created_equal: True
+    print(f"created_equal: {created == casted}")


        

