[Mlir-commits] [mlir] 436c6c9 - NFC: Break up the mlir python bindings into individual sources.

Stella Laurenzo llvmlistbot at llvm.org
Fri Mar 19 13:34:36 PDT 2021


Author: Stella Laurenzo
Date: 2021-03-19T13:33:51-07:00
New Revision: 436c6c9c20cc522c92a923440a5fc509c342a7db

URL: https://github.com/llvm/llvm-project/commit/436c6c9c20cc522c92a923440a5fc509c342a7db
DIFF: https://github.com/llvm/llvm-project/commit/436c6c9c20cc522c92a923440a5fc509c342a7db.diff

LOG: NFC: Break up the mlir python bindings into individual sources.

* IRModules.cpp -> (IRCore.cpp, IRAffine.cpp, IRAttributes.cpp, IRTypes.cpp).
* The individual pieces now compile in the 5-15s range whereas IRModules.cpp was starting to approach a minute (didn't capture a before time).
* More fine grained splitting is possible, but this represents the most obvious.

Differential Revision: https://reviews.llvm.org/D98978

Added: 
    mlir/lib/Bindings/Python/IRAffine.cpp
    mlir/lib/Bindings/Python/IRAttributes.cpp
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/IRModule.h
    mlir/lib/Bindings/Python/IRTypes.cpp

Modified: 
    mlir/lib/Bindings/Python/CMakeLists.txt
    mlir/lib/Bindings/Python/ExecutionEngine.cpp
    mlir/lib/Bindings/Python/MainModule.cpp
    mlir/lib/Bindings/Python/Pass.cpp

Removed: 
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/lib/Bindings/Python/IRModules.h


################################################################################
diff  --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt
index 5f042ec57c29..5fefa80398c7 100644
--- a/mlir/lib/Bindings/Python/CMakeLists.txt
+++ b/mlir/lib/Bindings/Python/CMakeLists.txt
@@ -70,7 +70,10 @@ add_mlir_python_extension(MLIRCoreBindingsPythonExtension _mlir
     python
   SOURCES
     MainModule.cpp
-    IRModules.cpp
+    IRAffine.cpp
+    IRAttributes.cpp
+    IRCore.cpp
+    IRTypes.cpp
     PybindUtils.cpp
     Pass.cpp
     ExecutionEngine.cpp

diff  --git a/mlir/lib/Bindings/Python/ExecutionEngine.cpp b/mlir/lib/Bindings/Python/ExecutionEngine.cpp
index f6f52e2e0aae..5ca9b1f68128 100644
--- a/mlir/lib/Bindings/Python/ExecutionEngine.cpp
+++ b/mlir/lib/Bindings/Python/ExecutionEngine.cpp
@@ -8,7 +8,7 @@
 
 #include "ExecutionEngine.h"
 
-#include "IRModules.h"
+#include "IRModule.h"
 #include "mlir-c/Bindings/Python/Interop.h"
 #include "mlir-c/ExecutionEngine.h"
 

diff  --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
new file mode 100644
index 000000000000..73a57d95e158
--- /dev/null
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -0,0 +1,781 @@
+//===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "IRModule.h"
+
+#include "PybindUtils.h"
+
+#include "mlir-c/AffineMap.h"
+#include "mlir-c/Bindings/Python/Interop.h"
+#include "mlir-c/IntegerSet.h"
+
+namespace py = pybind11;
+using namespace mlir;
+using namespace mlir::python;
+
+using llvm::SmallVector;
+using llvm::StringRef;
+using llvm::Twine;
+
+static const char kDumpDocstring[] =
+    R"(Dumps a debug representation of the object to stderr.)";
+
+/// Appends each item of `list` to `result`, casting the item to `PyType` and
+/// storing it as `CType`. If a cast fails, a `py::cast_error` is thrown whose
+/// message embeds `action` (a description of what the caller was attempting
+/// to do) followed by the text of the original error.
+template <typename PyType, typename CType>
+static void pyListToVector(py::list list, llvm::SmallVectorImpl<CType> &result,
+                           StringRef action) {
+  result.reserve(py::len(list));
+  for (py::handle item : list) {
+    try {
+      result.push_back(item.cast<PyType>());
+    } catch (const py::cast_error &err) {
+      std::string msg = (llvm::Twine("Invalid expression when ") + action +
+                         " (" + err.what() + ")")
+                            .str();
+      throw py::cast_error(msg);
+    } catch (const py::reference_cast_error &err) {
+      std::string msg = (llvm::Twine("Invalid expression (None?) when ") +
+                         action + " (" + err.what() + ")")
+                            .str();
+      throw py::cast_error(msg);
+    }
+  }
+}
+
+/// Returns true iff `permutation` holds each index in [0, size) exactly once.
+/// The vector is taken by const reference so that validating it does not
+/// copy the caller's data.
+template <typename PermutationTy>
+static bool isPermutation(const std::vector<PermutationTy> &permutation) {
+  llvm::SmallVector<bool, 8> seen(permutation.size(), false);
+  for (auto val : permutation) {
+    // Reject values that are out of range or that were already encountered.
+    if (!(val < permutation.size()) || seen[val])
+      return false;
+    seen[val] = true;
+  }
+  return true;
+}
+
+namespace {
+
+/// CRTP base class for Python MLIR affine expressions that subclass AffineExpr
+/// and should be castable from it. Intermediate hierarchy classes can be
+/// modeled by specifying BaseTy.
+template <typename DerivedTy, typename BaseTy = PyAffineExpr>
+class PyConcreteAffineExpr : public BaseTy {
+public:
+  // Derived classes must define statics for:
+  //   IsAFunctionTy isaFunction   (kind check consulted by castFrom)
+  //   const char *pyClassName     (Python class name used by bind)
+  // and may redefine bindDerived, which adds nothing by default.
+  using ClassTy = py::class_<DerivedTy, BaseTy>;
+  using IsAFunctionTy = bool (*)(MlirAffineExpr);
+  // Default, (context, expression) and checked-cast-from-generic constructors.
+  PyConcreteAffineExpr() = default;
+  PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
+      : BaseTy(std::move(contextRef), affineExpr) {}
+  PyConcreteAffineExpr(PyAffineExpr &orig)
+      : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {}
+  /// Returns `orig` if DerivedTy::isaFunction accepts it; else a ValueError.
+  static MlirAffineExpr castFrom(PyAffineExpr &orig) {
+    if (!DerivedTy::isaFunction(orig)) {
+      auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
+      throw SetPyError(PyExc_ValueError,
+                       Twine("Cannot cast affine expression to ") +
+                           DerivedTy::pyClassName + " (from " + origRepr + ")");
+    }
+    return orig;
+  }
+  /// Registers the Python class in `m` with a constructor from PyAffineExpr.
+  static void bind(py::module &m) {
+    auto cls = ClassTy(m, DerivedTy::pyClassName);
+    cls.def(py::init<PyAffineExpr &>());
+    DerivedTy::bindDerived(cls);
+  }
+
+  /// Hook for derived classes to add methods to the Python subclass; no-op here.
+  static void bindDerived(ClassTy &m) {}
+};
+
+class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant;
+  static constexpr const char *pyClassName = "AffineConstantExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  static PyAffineConstantExpr get(intptr_t value,
+                                  DefaultingPyMlirContext context) {
+    MlirAffineExpr affineExpr =
+        mlirAffineConstantExprGet(context->get(), static_cast<int64_t>(value));
+    return PyAffineConstantExpr(context->getRef(), affineExpr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"),
+                 py::arg("context") = py::none());
+    c.def_property_readonly("value", [](PyAffineConstantExpr &self) {
+      return mlirAffineConstantExprGetValue(self);
+    });
+  }
+};
+
+class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim;
+  static constexpr const char *pyClassName = "AffineDimExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) {
+    MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos);
+    return PyAffineDimExpr(context->getRef(), affineExpr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", &PyAffineDimExpr::get, py::arg("position"),
+                 py::arg("context") = py::none());
+    c.def_property_readonly("position", [](PyAffineDimExpr &self) {
+      return mlirAffineDimExprGetPosition(self);
+    });
+  }
+};
+
+class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol;
+  static constexpr const char *pyClassName = "AffineSymbolExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) {
+    MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos);
+    return PyAffineSymbolExpr(context->getRef(), affineExpr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"),
+                 py::arg("context") = py::none());
+    c.def_property_readonly("position", [](PyAffineSymbolExpr &self) {
+      return mlirAffineSymbolExprGetPosition(self);
+    });
+  }
+};
+
+class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary;
+  static constexpr const char *pyClassName = "AffineBinaryExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  PyAffineExpr lhs() {
+    MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get());
+    return PyAffineExpr(getContext(), lhsExpr);
+  }
+
+  PyAffineExpr rhs() {
+    MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get());
+    return PyAffineExpr(getContext(), rhsExpr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs);
+    c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs);
+  }
+};
+
+class PyAffineAddExpr
+    : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd;
+  static constexpr const char *pyClassName = "AffineAddExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  static PyAffineAddExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
+    MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs);
+    return PyAffineAddExpr(lhs.getContext(), expr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", &PyAffineAddExpr::get);
+  }
+};
+
+class PyAffineMulExpr
+    : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul;
+  static constexpr const char *pyClassName = "AffineMulExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  static PyAffineMulExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
+    MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs);
+    return PyAffineMulExpr(lhs.getContext(), expr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", &PyAffineMulExpr::get);
+  }
+};
+
+class PyAffineModExpr
+    : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod;
+  static constexpr const char *pyClassName = "AffineModExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  static PyAffineModExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
+    MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs);
+    return PyAffineModExpr(lhs.getContext(), expr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", &PyAffineModExpr::get);
+  }
+};
+
+class PyAffineFloorDivExpr
+    : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv;
+  static constexpr const char *pyClassName = "AffineFloorDivExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  static PyAffineFloorDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
+    MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs);
+    return PyAffineFloorDivExpr(lhs.getContext(), expr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", &PyAffineFloorDivExpr::get);
+  }
+};
+
+class PyAffineCeilDivExpr
+    : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv;
+  static constexpr const char *pyClassName = "AffineCeilDivExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  static PyAffineCeilDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
+    MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs);
+    return PyAffineCeilDivExpr(lhs.getContext(), expr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", &PyAffineCeilDivExpr::get);
+  }
+};
+
+} // namespace
+
+bool PyAffineExpr::operator==(const PyAffineExpr &other) {
+  return mlirAffineExprEqual(affineExpr, other.affineExpr);
+}
+
+py::object PyAffineExpr::getCapsule() {
+  return py::reinterpret_steal<py::object>(
+      mlirPythonAffineExprToCapsule(*this));
+}
+
+PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) {
+  MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr());
+  if (mlirAffineExprIsNull(rawAffineExpr))
+    throw py::error_already_set();
+  return PyAffineExpr(
+      PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)),
+      rawAffineExpr);
+}
+
+//------------------------------------------------------------------------------
+// PyAffineMap and utilities.
+//------------------------------------------------------------------------------
+namespace {
+
+/// A list of expressions contained in an affine map. Internally these are
+/// stored as a consecutive array leading to inexpensive random access. Both
+/// the map and the expression are owned by the context so we need not bother
+/// with lifetime extension.
+class PyAffineMapExprList
+    : public Sliceable<PyAffineMapExprList, PyAffineExpr> {
+public:
+  static constexpr const char *pyClassName = "AffineExprList";
+
+  PyAffineMapExprList(PyAffineMap map, intptr_t startIndex = 0,
+                      intptr_t length = -1, intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirAffineMapGetNumResults(map) : length,
+                  step),
+        affineMap(map) {}
+
+  intptr_t getNumElements() { return mlirAffineMapGetNumResults(affineMap); }
+
+  PyAffineExpr getElement(intptr_t pos) {
+    return PyAffineExpr(affineMap.getContext(),
+                        mlirAffineMapGetResult(affineMap, pos));
+  }
+
+  PyAffineMapExprList slice(intptr_t startIndex, intptr_t length,
+                            intptr_t step) {
+    return PyAffineMapExprList(affineMap, startIndex, length, step);
+  }
+
+private:
+  PyAffineMap affineMap;
+};
+} // end namespace
+
+bool PyAffineMap::operator==(const PyAffineMap &other) {
+  return mlirAffineMapEqual(affineMap, other.affineMap);
+}
+
+py::object PyAffineMap::getCapsule() {
+  return py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(*this));
+}
+
+PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) {
+  MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr());
+  if (mlirAffineMapIsNull(rawAffineMap))
+    throw py::error_already_set();
+  return PyAffineMap(
+      PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)),
+      rawAffineMap);
+}
+
+//------------------------------------------------------------------------------
+// PyIntegerSet and utilities.
+//------------------------------------------------------------------------------
+namespace {
+
+class PyIntegerSetConstraint {
+public:
+  PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {}
+
+  PyAffineExpr getExpr() {
+    return PyAffineExpr(set.getContext(),
+                        mlirIntegerSetGetConstraint(set, pos));
+  }
+
+  bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); }
+
+  static void bind(py::module &m) {
+    py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint")
+        .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr)
+        .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq);
+  }
+
+private:
+  PyIntegerSet set;
+  intptr_t pos;
+};
+
+class PyIntegerSetConstraintList
+    : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> {
+public:
+  static constexpr const char *pyClassName = "IntegerSetConstraintList";
+
+  PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0,
+                             intptr_t length = -1, intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirIntegerSetGetNumConstraints(set) : length,
+                  step),
+        set(set) {}
+
+  intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); }
+
+  PyIntegerSetConstraint getElement(intptr_t pos) {
+    return PyIntegerSetConstraint(set, pos);
+  }
+
+  PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length,
+                                   intptr_t step) {
+    return PyIntegerSetConstraintList(set, startIndex, length, step);
+  }
+
+private:
+  PyIntegerSet set;
+};
+} // namespace
+
+bool PyIntegerSet::operator==(const PyIntegerSet &other) {
+  return mlirIntegerSetEqual(integerSet, other.integerSet);
+}
+
+py::object PyIntegerSet::getCapsule() {
+  return py::reinterpret_steal<py::object>(
+      mlirPythonIntegerSetToCapsule(*this));
+}
+
+PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) {
+  MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr());
+  if (mlirIntegerSetIsNull(rawIntegerSet))
+    throw py::error_already_set();
+  return PyIntegerSet(
+      PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)),
+      rawIntegerSet);
+}
+
+void mlir::python::populateIRAffine(py::module &m) {
+  //----------------------------------------------------------------------------
+  // Mapping of PyAffineExpr and derived classes.
+  //----------------------------------------------------------------------------
+  py::class_<PyAffineExpr>(m, "AffineExpr")
+      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
+                             &PyAffineExpr::getCapsule)
+      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule)
+      .def("__add__",
+           [](PyAffineExpr &self, PyAffineExpr &other) {
+             return PyAffineAddExpr::get(self, other);
+           })
+      .def("__mul__",
+           [](PyAffineExpr &self, PyAffineExpr &other) {
+             return PyAffineMulExpr::get(self, other);
+           })
+      .def("__mod__",
+           [](PyAffineExpr &self, PyAffineExpr &other) {
+             return PyAffineModExpr::get(self, other);
+           })
+      .def("__sub__",
+           [](PyAffineExpr &self, PyAffineExpr &other) {
+             auto negOne =
+                 PyAffineConstantExpr::get(-1, *self.getContext().get());
+             return PyAffineAddExpr::get(self,
+                                         PyAffineMulExpr::get(negOne, other));
+           })
+      .def("__eq__", [](PyAffineExpr &self,
+                        PyAffineExpr &other) { return self == other; })
+      .def("__eq__",
+           [](PyAffineExpr &self, py::object &other) { return false; })
+      .def("__str__",
+           [](PyAffineExpr &self) {
+             PyPrintAccumulator printAccum;
+             mlirAffineExprPrint(self, printAccum.getCallback(),
+                                 printAccum.getUserData());
+             return printAccum.join();
+           })
+      .def("__repr__",
+           [](PyAffineExpr &self) {
+             PyPrintAccumulator printAccum;
+             printAccum.parts.append("AffineExpr(");
+             mlirAffineExprPrint(self, printAccum.getCallback(),
+                                 printAccum.getUserData());
+             printAccum.parts.append(")");
+             return printAccum.join();
+           })
+      .def_property_readonly(
+          "context",
+          [](PyAffineExpr &self) { return self.getContext().getObject(); })
+      .def_static(
+          "get_add", &PyAffineAddExpr::get,
+          "Gets an affine expression containing a sum of two expressions.")
+      .def_static(
+          "get_mul", &PyAffineMulExpr::get,
+          "Gets an affine expression containing a product of two expressions.")
+      .def_static("get_mod", &PyAffineModExpr::get,
+                  "Gets an affine expression containing the modulo of dividing "
+                  "one expression by another.")
+      .def_static("get_floor_div", &PyAffineFloorDivExpr::get,
+                  "Gets an affine expression containing the rounded-down "
+                  "result of dividing one expression by another.")
+      .def_static("get_ceil_div", &PyAffineCeilDivExpr::get,
+                  "Gets an affine expression containing the rounded-up result "
+                  "of dividing one expression by another.")
+      .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"),
+                  py::arg("context") = py::none(),
+                  "Gets a constant affine expression with the given value.")
+      .def_static(
+          "get_dim", &PyAffineDimExpr::get, py::arg("position"),
+          py::arg("context") = py::none(),
+          "Gets an affine expression of a dimension at the given position.")
+      .def_static(
+          "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"),
+          py::arg("context") = py::none(),
+          "Gets an affine expression of a symbol at the given position.")
+      .def(
+          "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); },
+          kDumpDocstring);
+  PyAffineConstantExpr::bind(m);
+  PyAffineDimExpr::bind(m);
+  PyAffineSymbolExpr::bind(m);
+  PyAffineBinaryExpr::bind(m);
+  PyAffineAddExpr::bind(m);
+  PyAffineMulExpr::bind(m);
+  PyAffineModExpr::bind(m);
+  PyAffineFloorDivExpr::bind(m);
+  PyAffineCeilDivExpr::bind(m);
+
+  //----------------------------------------------------------------------------
+  // Mapping of PyAffineMap.
+  //----------------------------------------------------------------------------
+  py::class_<PyAffineMap>(m, "AffineMap")
+      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
+                             &PyAffineMap::getCapsule)
+      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule)
+      .def("__eq__",
+           [](PyAffineMap &self, PyAffineMap &other) { return self == other; })
+      .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; })
+      .def("__str__",
+           [](PyAffineMap &self) {
+             PyPrintAccumulator printAccum;
+             mlirAffineMapPrint(self, printAccum.getCallback(),
+                                printAccum.getUserData());
+             return printAccum.join();
+           })
+      .def("__repr__",
+           [](PyAffineMap &self) {
+             PyPrintAccumulator printAccum;
+             printAccum.parts.append("AffineMap(");
+             mlirAffineMapPrint(self, printAccum.getCallback(),
+                                printAccum.getUserData());
+             printAccum.parts.append(")");
+             return printAccum.join();
+           })
+      .def_property_readonly(
+          "context",
+          [](PyAffineMap &self) { return self.getContext().getObject(); },
+          "Context that owns the Affine Map")
+      .def(
+          "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); },
+          kDumpDocstring)
+      .def_static(
+          "get",
+          [](intptr_t dimCount, intptr_t symbolCount, py::list exprs,
+             DefaultingPyMlirContext context) {
+            SmallVector<MlirAffineExpr> affineExprs;
+            pyListToVector<PyAffineExpr, MlirAffineExpr>(
+                exprs, affineExprs, "attempting to create an AffineMap");
+            MlirAffineMap map =
+                mlirAffineMapGet(context->get(), dimCount, symbolCount,
+                                 affineExprs.size(), affineExprs.data());
+            return PyAffineMap(context->getRef(), map);
+          },
+          py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"),
+          py::arg("context") = py::none(),
+          "Gets a map with the given expressions as results.")
+      .def_static(
+          "get_constant",
+          [](intptr_t value, DefaultingPyMlirContext context) {
+            MlirAffineMap affineMap =
+                mlirAffineMapConstantGet(context->get(), value);
+            return PyAffineMap(context->getRef(), affineMap);
+          },
+          py::arg("value"), py::arg("context") = py::none(),
+          "Gets an affine map with a single constant result")
+      .def_static(
+          "get_empty",
+          [](DefaultingPyMlirContext context) {
+            MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get());
+            return PyAffineMap(context->getRef(), affineMap);
+          },
+          py::arg("context") = py::none(), "Gets an empty affine map.")
+      .def_static(
+          "get_identity",
+          [](intptr_t nDims, DefaultingPyMlirContext context) {
+            MlirAffineMap affineMap =
+                mlirAffineMapMultiDimIdentityGet(context->get(), nDims);
+            return PyAffineMap(context->getRef(), affineMap);
+          },
+          py::arg("n_dims"), py::arg("context") = py::none(),
+          "Gets an identity map with the given number of dimensions.")
+      .def_static(
+          "get_minor_identity",
+          [](intptr_t nDims, intptr_t nResults,
+             DefaultingPyMlirContext context) {
+            MlirAffineMap affineMap =
+                mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults);
+            return PyAffineMap(context->getRef(), affineMap);
+          },
+          py::arg("n_dims"), py::arg("n_results"),
+          py::arg("context") = py::none(),
+          "Gets a minor identity map with the given number of dimensions and "
+          "results.")
+      .def_static(
+          "get_permutation",
+          [](std::vector<unsigned> permutation,
+             DefaultingPyMlirContext context) {
+            if (!isPermutation(permutation))
+              throw py::cast_error("Invalid permutation when attempting to "
+                                   "create an AffineMap");
+            MlirAffineMap affineMap = mlirAffineMapPermutationGet(
+                context->get(), permutation.size(), permutation.data());
+            return PyAffineMap(context->getRef(), affineMap);
+          },
+          py::arg("permutation"), py::arg("context") = py::none(),
+          "Gets an affine map that permutes its inputs.")
+      .def("get_submap",
+           [](PyAffineMap &self, std::vector<intptr_t> &resultPos) {
+             intptr_t numResults = mlirAffineMapGetNumResults(self);
+             for (intptr_t pos : resultPos) {
+               if (pos < 0 || pos >= numResults)
+                 throw py::value_error("result position out of bounds");
+             }
+             MlirAffineMap affineMap = mlirAffineMapGetSubMap(
+                 self, resultPos.size(), resultPos.data());
+             return PyAffineMap(self.getContext(), affineMap);
+           })
+      .def("get_major_submap",
+           [](PyAffineMap &self, intptr_t nResults) {
+             // Also reject negative counts, as get_submap does.
+             if (nResults < 0 || nResults >= mlirAffineMapGetNumResults(self))
+               throw py::value_error("number of results out of bounds");
+             return PyAffineMap(self.getContext(),
+                                mlirAffineMapGetMajorSubMap(self, nResults));
+           })
+      .def("get_minor_submap",
+           [](PyAffineMap &self, intptr_t nResults) {
+             // Also reject negative counts, as get_submap does.
+             if (nResults < 0 || nResults >= mlirAffineMapGetNumResults(self))
+               throw py::value_error("number of results out of bounds");
+             return PyAffineMap(self.getContext(),
+                                mlirAffineMapGetMinorSubMap(self, nResults));
+           })
+      .def_property_readonly(
+          "is_permutation",
+          [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); })
+      .def_property_readonly("is_projected_permutation",
+                             [](PyAffineMap &self) {
+                               return mlirAffineMapIsProjectedPermutation(self);
+                             })
+      .def_property_readonly(
+          "n_dims",
+          [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); })
+      .def_property_readonly(
+          "n_inputs",
+          [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); })
+      .def_property_readonly(
+          "n_symbols",
+          [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); })
+      .def_property_readonly("results", [](PyAffineMap &self) {
+        return PyAffineMapExprList(self);
+      });
+  PyAffineMapExprList::bind(m);
+
+  //----------------------------------------------------------------------------
+  // Mapping of PyIntegerSet.
+  //----------------------------------------------------------------------------
+  py::class_<PyIntegerSet>(m, "IntegerSet")
+      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
+                             &PyIntegerSet::getCapsule)
+      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)
+      .def("__eq__", [](PyIntegerSet &self,
+                        PyIntegerSet &other) { return self == other; })
+      .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; })
+      .def("__str__",
+           [](PyIntegerSet &self) {
+             PyPrintAccumulator printAccum;
+             mlirIntegerSetPrint(self, printAccum.getCallback(),
+                                 printAccum.getUserData());
+             return printAccum.join();
+           })
+      .def("__repr__",
+           [](PyIntegerSet &self) {
+             PyPrintAccumulator printAccum;
+             printAccum.parts.append("IntegerSet(");
+             mlirIntegerSetPrint(self, printAccum.getCallback(),
+                                 printAccum.getUserData());
+             printAccum.parts.append(")");
+             return printAccum.join();
+           })
+      .def_property_readonly(
+          "context",
+          [](PyIntegerSet &self) { return self.getContext().getObject(); })
+      .def(
+          "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
+          kDumpDocstring)
+      .def_static(
+          "get",
+          [](intptr_t numDims, intptr_t numSymbols, py::list exprs,
+             std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
+            if (exprs.size() != eqFlags.size())
+              throw py::value_error(
+                  "Expected the number of constraints to match "
+                  "that of equality flags");
+            if (exprs.empty())
+              throw py::value_error("Expected non-empty list of constraints");
+
+            // Copy over to a SmallVector because std::vector has a
+            // specialization for booleans that packs data and does not
+            // expose a `bool *`.
+            SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end());
+
+            SmallVector<MlirAffineExpr> affineExprs;
+            pyListToVector<PyAffineExpr>(exprs, affineExprs,
+                                         "attempting to create an IntegerSet");
+            MlirIntegerSet set = mlirIntegerSetGet(
+                context->get(), numDims, numSymbols, exprs.size(),
+                affineExprs.data(), flags.data());
+            return PyIntegerSet(context->getRef(), set);
+          },
+          py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"),
+          py::arg("eq_flags"), py::arg("context") = py::none())
+      .def_static(
+          "get_empty",
+          [](intptr_t numDims, intptr_t numSymbols,
+             DefaultingPyMlirContext context) {
+            MlirIntegerSet set =
+                mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols);
+            return PyIntegerSet(context->getRef(), set);
+          },
+          py::arg("num_dims"), py::arg("num_symbols"),
+          py::arg("context") = py::none())
+      .def("get_replaced",
+           [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs,
+              intptr_t numResultDims, intptr_t numResultSymbols) {
+             if (static_cast<intptr_t>(dimExprs.size()) !=
+                 mlirIntegerSetGetNumDims(self))
+               throw py::value_error(
+                   "Expected the number of dimension replacement expressions "
+                   "to match that of dimensions");
+             if (static_cast<intptr_t>(symbolExprs.size()) !=
+                 mlirIntegerSetGetNumSymbols(self))
+               throw py::value_error(
+                   "Expected the number of symbol replacement expressions "
+                   "to match that of symbols");
+
+             SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs;
+             pyListToVector<PyAffineExpr>(
+                 dimExprs, dimAffineExprs,
+                 "attempting to create an IntegerSet by replacing dimensions");
+             pyListToVector<PyAffineExpr>(
+                 symbolExprs, symbolAffineExprs,
+                 "attempting to create an IntegerSet by replacing symbols");
+             MlirIntegerSet set = mlirIntegerSetReplaceGet(
+                 self, dimAffineExprs.data(), symbolAffineExprs.data(),
+                 numResultDims, numResultSymbols);
+             return PyIntegerSet(self.getContext(), set);
+           })
+      .def_property_readonly("is_canonical_empty",
+                             [](PyIntegerSet &self) {
+                               return mlirIntegerSetIsCanonicalEmpty(self);
+                             })
+      .def_property_readonly(
+          "n_dims",
+          [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); })
+      .def_property_readonly(
+          "n_symbols",
+          [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); })
+      .def_property_readonly(
+          "n_inputs",
+          [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); })
+      .def_property_readonly("n_equalities",
+                             [](PyIntegerSet &self) {
+                               return mlirIntegerSetGetNumEqualities(self);
+                             })
+      .def_property_readonly("n_inequalities",
+                             [](PyIntegerSet &self) {
+                               return mlirIntegerSetGetNumInequalities(self);
+                             })
+      .def_property_readonly("constraints", [](PyIntegerSet &self) {
+        return PyIntegerSetConstraintList(self);
+      });
+  PyIntegerSetConstraint::bind(m);
+  PyIntegerSetConstraintList::bind(m);
+}

diff  --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
new file mode 100644
index 000000000000..6f9206c1b912
--- /dev/null
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -0,0 +1,761 @@
+//===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "IRModule.h"
+
+#include "PybindUtils.h"
+
+#include "mlir-c/BuiltinAttributes.h"
+#include "mlir-c/BuiltinTypes.h"
+
+namespace py = pybind11;
+using namespace mlir;
+using namespace mlir::python;
+
+using llvm::SmallVector;
+using llvm::StringRef;
+using llvm::Twine;
+
+namespace {
+
+/// Builds an MlirStringRef from the data pointer and size of `s`.
+/// NOTE(review): presumably the result only borrows `s`'s storage and must
+/// not outlive it - confirm against mlirStringRefCreate.
+static MlirStringRef toMlirStringRef(const std::string &s) {
+  const char *bytes = s.data();
+  const size_t numBytes = s.size();
+  return mlirStringRefCreate(bytes, numBytes);
+}
+
+/// CRTP base classes for Python attributes that subclass Attribute and should
+/// be castable from it (i.e. via something like StringAttr(attr)).
+/// By default, attribute class hierarchies are one level deep (i.e. a
+/// concrete attribute class extends PyAttribute); however, intermediate
+/// python-visible base classes can be modeled by specifying a BaseTy.
+template <typename DerivedTy, typename BaseTy = PyAttribute>
+class PyConcreteAttribute : public BaseTy {
+public:
+  // Every DerivedTy is expected to provide these static members:
+  //   IsAFunctionTy isaFunction  - predicate consulted by castFrom()
+  //   const char *pyClassName    - name of the Python class made by bind()
+  using IsAFunctionTy = bool (*)(MlirAttribute);
+  using ClassTy = py::class_<DerivedTy, BaseTy>;
+
+  PyConcreteAttribute() = default;
+  PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
+      : BaseTy(std::move(contextRef), attr) {}
+  /// Converting constructor used for the Python-side `DerivedTy(attr)` cast;
+  /// castFrom() throws when `orig` is not of the derived kind.
+  PyConcreteAttribute(PyAttribute &orig)
+      : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
+
+  /// Returns `orig` converted to MlirAttribute when DerivedTy::isaFunction
+  /// accepts it; otherwise raises a ValueError that names the target class
+  /// and includes repr(orig).
+  static MlirAttribute castFrom(PyAttribute &orig) {
+    if (DerivedTy::isaFunction(orig))
+      return orig;
+
+    auto sourceRepr = py::repr(py::cast(orig)).cast<std::string>();
+    throw SetPyError(PyExc_ValueError, Twine("Cannot cast attribute to ") +
+                                           DerivedTy::pyClassName +
+                                           " (from " + sourceRepr + ")");
+  }
+
+  /// Creates the Python class DerivedTy::pyClassName on `m`, gives it a
+  /// constructor taking a generic attribute, then lets the derived class add
+  /// its own members through bindDerived().
+  static void bind(py::module &m) {
+    // buffer_protocol() is requested for every attribute class; only classes
+    // that also call def_buffer() end up exporting a buffer.
+    ClassTy pyClass(m, DerivedTy::pyClassName, py::buffer_protocol());
+    pyClass.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>());
+    DerivedTy::bindDerived(pyClass);
+  }
+
+  /// Customization hook for derived classes; the default adds nothing.
+  static void bindDerived(ClassTy &m) {}
+};
+
+/// Attribute wrapping an AffineMap - AffineMapAttr.
+class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
+  static constexpr const char *pyClassName = "AffineMapAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    // AffineMapAttr.get(affine_map): the result reuses the map's context.
+    auto getFromMap = [](PyAffineMap &affineMap) {
+      MlirAttribute wrapped = mlirAffineMapAttrGet(affineMap.get());
+      return PyAffineMapAttribute(affineMap.getContext(), wrapped);
+    };
+    c.def_static("get", getFromMap, py::arg("affine_map"),
+                 "Gets an attribute wrapping an AffineMap.");
+  }
+};
+
+/// Array attribute subclass - ArrayAttr. An ordered list of attributes that
+/// supports indexing, len() and iteration from Python.
+class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
+  static constexpr const char *pyClassName = "ArrayAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  /// Iterator object handed out by ArrayAttr.__iter__; yields the elements in
+  /// index order.
+  class PyArrayAttributeIterator {
+  public:
+    PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {}
+
+    PyArrayAttributeIterator &dunderIter() { return *this; }
+
+    /// Returns the next element, or raises StopIteration once every element
+    /// has been produced.
+    PyAttribute dunderNext() {
+      if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) {
+        throw py::stop_iteration();
+      }
+      return PyAttribute(attr.getContext(),
+                         mlirArrayAttrGetElement(attr.get(), nextIndex++));
+    }
+
+    static void bind(py::module &m) {
+      py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator")
+          .def("__iter__", &PyArrayAttributeIterator::dunderIter)
+          .def("__next__", &PyArrayAttributeIterator::dunderNext);
+    }
+
+  private:
+    PyAttribute attr;
+    int nextIndex = 0;
+  };
+
+  static void bindDerived(ClassTy &c) {
+    // ArrayAttr.get(attributes, context=None): every list item must be an
+    // Attribute; conversion failures are re-raised with added context.
+    c.def_static(
+        "get",
+        [](py::list attributes, DefaultingPyMlirContext context) {
+          SmallVector<MlirAttribute> mlirAttributes;
+          mlirAttributes.reserve(py::len(attributes));
+          for (auto attribute : attributes) {
+            try {
+              mlirAttributes.push_back(attribute.cast<PyAttribute>());
+            } catch (py::cast_error &err) {
+              std::string msg = std::string("Invalid attribute when attempting "
+                                            "to create an ArrayAttribute (") +
+                                err.what() + ")";
+              throw py::cast_error(msg);
+            } catch (py::reference_cast_error &err) {
+              // This exception seems thrown when the value is "None".
+              std::string msg =
+                  std::string("Invalid attribute (None?) when attempting to "
+                              "create an ArrayAttribute (") +
+                  err.what() + ")";
+              throw py::cast_error(msg);
+            }
+          }
+          MlirAttribute attr = mlirArrayAttrGet(
+              context->get(), mlirAttributes.size(), mlirAttributes.data());
+          return PyArrayAttribute(context->getRef(), attr);
+        },
+        py::arg("attributes"), py::arg("context") = py::none(),
+        "Gets a uniqued Array attribute");
+    c.def("__getitem__",
+          [](PyArrayAttribute &arr, intptr_t i) {
+            // Reject negative indices as well as too-large ones; only
+            // positions in [0, len) are forwarded to the C API.
+            if (i < 0 || i >= mlirArrayAttrGetNumElements(arr))
+              throw py::index_error("ArrayAttribute index out of range");
+            return PyAttribute(arr.getContext(),
+                               mlirArrayAttrGetElement(arr, i));
+          })
+        .def("__len__",
+             [](const PyArrayAttribute &arr) {
+               return mlirArrayAttrGetNumElements(arr);
+             })
+        .def("__iter__", [](const PyArrayAttribute &arr) {
+          return PyArrayAttributeIterator(arr);
+        });
+  }
+};
+
+/// Floating-point attribute subclass - FloatAttr.
+class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
+  static constexpr const char *pyClassName = "FloatAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    // FloatAttr.get(type, value, loc=None): checked construction. A null
+    // attribute returned by the C API is reported as a ValueError.
+    auto getChecked = [](PyType &type, double value,
+                         DefaultingPyLocation loc) {
+      MlirAttribute result = mlirFloatAttrDoubleGetChecked(loc, type, value);
+      // TODO: Rework error reporting once diagnostic engine is exposed
+      // in C API.
+      if (!mlirAttributeIsNull(result))
+        return PyFloatAttribute(type.getContext(), result);
+      auto typeRepr = py::repr(py::cast(type)).cast<std::string>();
+      throw SetPyError(PyExc_ValueError,
+                       Twine("invalid '") + typeRepr +
+                           "' and expected floating point type.");
+    };
+    // FloatAttr.get_f32(value, context=None): element type fixed to f32.
+    auto getF32 = [](double value, DefaultingPyMlirContext context) {
+      MlirContext ctx = context->get();
+      MlirAttribute result =
+          mlirFloatAttrDoubleGet(ctx, mlirF32TypeGet(ctx), value);
+      return PyFloatAttribute(context->getRef(), result);
+    };
+    // FloatAttr.get_f64(value, context=None): element type fixed to f64.
+    auto getF64 = [](double value, DefaultingPyMlirContext context) {
+      MlirContext ctx = context->get();
+      MlirAttribute result =
+          mlirFloatAttrDoubleGet(ctx, mlirF64TypeGet(ctx), value);
+      return PyFloatAttribute(context->getRef(), result);
+    };
+    // FloatAttr.value: the stored value read back as a double.
+    auto getValue = [](PyFloatAttribute &self) {
+      return mlirFloatAttrGetValueDouble(self);
+    };
+
+    c.def_static("get", getChecked, py::arg("type"), py::arg("value"),
+                 py::arg("loc") = py::none(),
+                 "Gets an uniqued float point attribute associated to a type");
+    c.def_static(
+        "get_f32", getF32, py::arg("value"), py::arg("context") = py::none(),
+        "Gets an uniqued float point attribute associated to a f32 type");
+    c.def_static(
+        "get_f64", getF64, py::arg("value"), py::arg("context") = py::none(),
+        "Gets an uniqued float point attribute associated to a f64 type");
+    c.def_property_readonly("value", getValue,
+                            "Returns the value of the float point attribute");
+  }
+};
+
+/// Integer attribute subclass - IntegerAttr.
+class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
+  static constexpr const char *pyClassName = "IntegerAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    // IntegerAttr.get(type, value): the result reuses the type's context.
+    auto getTyped = [](PyType &type, int64_t value) {
+      MlirAttribute result = mlirIntegerAttrGet(type, value);
+      return PyIntegerAttribute(type.getContext(), result);
+    };
+    // IntegerAttr.value: the stored integer as returned by the C API.
+    auto getValue = [](PyIntegerAttribute &self) {
+      return mlirIntegerAttrGetValueInt(self);
+    };
+
+    c.def_static("get", getTyped, py::arg("type"), py::arg("value"),
+                 "Gets an uniqued integer attribute associated to a type");
+    c.def_property_readonly("value", getValue,
+                            "Returns the value of the integer attribute");
+  }
+};
+
+/// Boolean attribute subclass - BoolAttr.
+class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
+  static constexpr const char *pyClassName = "BoolAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    // BoolAttr.get(value, context=None).
+    auto getFromBool = [](bool value, DefaultingPyMlirContext context) {
+      MlirAttribute result = mlirBoolAttrGet(context->get(), value);
+      return PyBoolAttribute(context->getRef(), result);
+    };
+    // BoolAttr.value: the stored flag.
+    auto getValue = [](PyBoolAttribute &self) {
+      return mlirBoolAttrGetValue(self);
+    };
+
+    c.def_static("get", getFromBool, py::arg("value"),
+                 py::arg("context") = py::none(),
+                 "Gets an uniqued bool attribute");
+    c.def_property_readonly("value", getValue,
+                            "Returns the value of the bool attribute");
+  }
+};
+
+/// Flat symbol reference attribute subclass - FlatSymbolRefAttr.
+class PyFlatSymbolRefAttribute
+    : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
+  static constexpr const char *pyClassName = "FlatSymbolRefAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    // FlatSymbolRefAttr.get(value, context=None).
+    auto getFromName = [](std::string value, DefaultingPyMlirContext context) {
+      MlirStringRef symbol = toMlirStringRef(value);
+      MlirAttribute result = mlirFlatSymbolRefAttrGet(context->get(), symbol);
+      return PyFlatSymbolRefAttribute(context->getRef(), result);
+    };
+    // FlatSymbolRefAttr.value: the referenced name as a Python str, built
+    // from the returned pointer and length.
+    auto getValue = [](PyFlatSymbolRefAttribute &self) {
+      MlirStringRef symbol = mlirFlatSymbolRefAttrGetValue(self);
+      return py::str(symbol.data, symbol.length);
+    };
+
+    c.def_static("get", getFromName, py::arg("value"),
+                 py::arg("context") = py::none(),
+                 "Gets a uniqued FlatSymbolRef attribute");
+    c.def_property_readonly(
+        "value", getValue,
+        "Returns the value of the FlatSymbolRef attribute as a string");
+  }
+};
+
+/// String attribute subclass - StringAttr.
+class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
+  static constexpr const char *pyClassName = "StringAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    // StringAttr.get(value, context=None).
+    c.def_static(
+        "get",
+        [](std::string value, DefaultingPyMlirContext context) {
+          MlirAttribute attr =
+              mlirStringAttrGet(context->get(), toMlirStringRef(value));
+          return PyStringAttribute(context->getRef(), attr);
+        },
+        py::arg("value"), py::arg("context") = py::none(),
+        "Gets a uniqued string attribute");
+    // StringAttr.get_typed(type, value): the result reuses the type's
+    // context. Arguments are named so they can be passed by keyword, like
+    // those of every other factory in this file.
+    c.def_static(
+        "get_typed",
+        [](PyType &type, std::string value) {
+          MlirAttribute attr =
+              mlirStringAttrTypedGet(type, toMlirStringRef(value));
+          return PyStringAttribute(type.getContext(), attr);
+        },
+        py::arg("type"), py::arg("value"),
+        "Gets a uniqued string attribute associated to a type");
+    // StringAttr.value: contents as a Python str, built from the returned
+    // pointer and length.
+    c.def_property_readonly(
+        "value",
+        [](PyStringAttribute &self) {
+          MlirStringRef stringRef = mlirStringAttrGetValue(self);
+          return py::str(stringRef.data, stringRef.length);
+        },
+        "Returns the value of the string attribute");
+  }
+};
+
+// TODO: Support construction of bool elements.
+// TODO: Support construction of string elements.
+class PyDenseElementsAttribute
+    : public PyConcreteAttribute<PyDenseElementsAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
+  static constexpr const char *pyClassName = "DenseElementsAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  /// Builds a DenseElementsAttr from any Python object exposing the buffer
+  /// protocol (bound as DenseElementsAttr.get).
+  ///
+  /// The element type is chosen from the buffer's format string and item size:
+  /// "f" -> f32, "d" -> f64, and 4- or 8-byte integer formats -> 32/64-bit
+  /// integers. For integers, `signless` selects a signless MLIR integer type;
+  /// otherwise the signed/unsigned variant matching the format is used.
+  /// The result is built as a ranked tensor with the buffer's shape (see
+  /// bulkLoad).
+  ///
+  /// Raises the pending Python error if the buffer cannot be acquired, and
+  /// ValueError for any format/size combination not listed above.
+  static PyDenseElementsAttribute
+  getFromBuffer(py::buffer array, bool signless,
+                DefaultingPyMlirContext contextWrapper) {
+    // Request a contiguous view. In exotic cases, this will cause a copy.
+    int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
+    Py_buffer *view = new Py_buffer();
+    if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
+      delete view;
+      throw py::error_already_set();
+    }
+    // NOTE(review): `view` is never released or deleted below, so
+    // py::buffer_info presumably takes ownership of it - confirm against the
+    // pybind11 buffer_info(Py_buffer *) constructor.
+    py::buffer_info arrayInfo(view);
+
+    MlirContext context = contextWrapper->get();
+    // Switch on the types that can be bulk loaded between the Python and
+    // MLIR-C APIs.
+    // See: https://docs.python.org/3/library/struct.html#format-characters
+    if (arrayInfo.format == "f") {
+      // f32
+      assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
+      return PyDenseElementsAttribute(
+          contextWrapper->getRef(),
+          bulkLoad(context, mlirDenseElementsAttrFloatGet,
+                   mlirF32TypeGet(context), arrayInfo));
+    } else if (arrayInfo.format == "d") {
+      // f64
+      assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
+      return PyDenseElementsAttribute(
+          contextWrapper->getRef(),
+          bulkLoad(context, mlirDenseElementsAttrDoubleGet,
+                   mlirF64TypeGet(context), arrayInfo));
+    } else if (isSignedIntegerFormat(arrayInfo.format)) {
+      // Signed formats of any other item size fall through to the error below.
+      if (arrayInfo.itemsize == 4) {
+        // i32
+        MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
+                                        : mlirIntegerTypeSignedGet(context, 32);
+        return PyDenseElementsAttribute(contextWrapper->getRef(),
+                                        bulkLoad(context,
+                                                 mlirDenseElementsAttrInt32Get,
+                                                 elementType, arrayInfo));
+      } else if (arrayInfo.itemsize == 8) {
+        // i64
+        MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
+                                        : mlirIntegerTypeSignedGet(context, 64);
+        return PyDenseElementsAttribute(contextWrapper->getRef(),
+                                        bulkLoad(context,
+                                                 mlirDenseElementsAttrInt64Get,
+                                                 elementType, arrayInfo));
+      }
+    } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
+      // Unsigned formats of any other item size fall through to the error
+      // below.
+      if (arrayInfo.itemsize == 4) {
+        // unsigned i32
+        MlirType elementType = signless
+                                   ? mlirIntegerTypeGet(context, 32)
+                                   : mlirIntegerTypeUnsignedGet(context, 32);
+        return PyDenseElementsAttribute(contextWrapper->getRef(),
+                                        bulkLoad(context,
+                                                 mlirDenseElementsAttrUInt32Get,
+                                                 elementType, arrayInfo));
+      } else if (arrayInfo.itemsize == 8) {
+        // unsigned i64
+        MlirType elementType = signless
+                                   ? mlirIntegerTypeGet(context, 64)
+                                   : mlirIntegerTypeUnsignedGet(context, 64);
+        return PyDenseElementsAttribute(contextWrapper->getRef(),
+                                        bulkLoad(context,
+                                                 mlirDenseElementsAttrUInt64Get,
+                                                 elementType, arrayInfo));
+      }
+    }
+
+    // TODO: Fall back to string-based get.
+    std::string message = "unimplemented array format conversion from format: ";
+    message.append(arrayInfo.format);
+    throw SetPyError(PyExc_ValueError, message);
+  }
+
+  /// Builds a DenseElementsAttr of type `shapedType` whose elements are all
+  /// `elementAttr` (bound as DenseElementsAttr.get_splat).
+  ///
+  /// Raises ValueError when `elementAttr` is neither an integer nor a float
+  /// attribute, when `shapedType` is not a statically shaped ShapedType, or
+  /// when the shaped type's element type differs from the attribute's type.
+  static PyDenseElementsAttribute getSplat(PyType shapedType,
+                                           PyAttribute &elementAttr) {
+    auto contextWrapper =
+        PyMlirContext::forContext(mlirTypeGetContext(shapedType));
+
+    // 1. The element must be an integer or float attribute.
+    const bool isIntOrFloat = mlirAttributeIsAInteger(elementAttr) ||
+                              mlirAttributeIsAFloat(elementAttr);
+    if (!isIntOrFloat) {
+      std::string message = "Illegal element type for DenseElementsAttr: ";
+      message.append(py::repr(py::cast(elementAttr)));
+      throw SetPyError(PyExc_ValueError, message);
+    }
+
+    // 2. The target type must be shaped, with a fully static shape.
+    const bool isStaticShaped = mlirTypeIsAShaped(shapedType) &&
+                                mlirShapedTypeHasStaticShape(shapedType);
+    if (!isStaticShaped) {
+      std::string message =
+          "Expected a static ShapedType for the shaped_type parameter: ";
+      message.append(py::repr(py::cast(shapedType)));
+      throw SetPyError(PyExc_ValueError, message);
+    }
+
+    // 3. The element's type must match the shaped type's element type.
+    MlirType expectedElementType = mlirShapedTypeGetElementType(shapedType);
+    MlirType actualElementType = mlirAttributeGetType(elementAttr);
+    if (!mlirTypeEqual(expectedElementType, actualElementType)) {
+      std::string message =
+          "Shaped element type and attribute type must be equal: shaped=";
+      message.append(py::repr(py::cast(shapedType)));
+      message.append(", element=");
+      message.append(py::repr(py::cast(elementAttr)));
+      throw SetPyError(PyExc_ValueError, message);
+    }
+
+    MlirAttribute splat =
+        mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
+    return PyDenseElementsAttribute(contextWrapper->getRef(), splat);
+  }
+
+  /// Implements len(): the element count reported by
+  /// mlirElementsAttrGetNumElements.
+  intptr_t dunderLen() {
+    const intptr_t numElements = mlirElementsAttrGetNumElements(*this);
+    return numElements;
+  }
+
+  /// Buffer-protocol entry point (registered with def_buffer in bindDerived).
+  ///
+  /// Picks the C++ element type from the attribute's element type: f32, f64,
+  /// or 32/64-bit integers (signless and signed map to the signed C++ type,
+  /// unsigned to the unsigned one). Any other element type raises ValueError.
+  py::buffer_info accessBuffer() {
+    MlirType shapedType = mlirAttributeGetType(*this);
+    MlirType elementType = mlirShapedTypeGetElementType(shapedType);
+
+    // Floating point element types.
+    if (mlirTypeIsAF32(elementType))
+      return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue);
+    if (mlirTypeIsAF64(elementType))
+      return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue);
+
+    // Integer element types, dispatched on bit width and then signedness.
+    if (mlirTypeIsAInteger(elementType)) {
+      const unsigned width = mlirIntegerTypeGetWidth(elementType);
+      if (width == 32) {
+        if (mlirIntegerTypeIsSignless(elementType) ||
+            mlirIntegerTypeIsSigned(elementType))
+          return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value);
+        if (mlirIntegerTypeIsUnsigned(elementType))
+          return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value);
+      } else if (width == 64) {
+        if (mlirIntegerTypeIsSignless(elementType) ||
+            mlirIntegerTypeIsSigned(elementType))
+          return bufferInfo(shapedType, mlirDenseElementsAttrGetInt64Value);
+        if (mlirIntegerTypeIsUnsigned(elementType))
+          return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value);
+      }
+    }
+
+    std::string message = "unimplemented array format.";
+    throw SetPyError(PyExc_ValueError, message);
+  }
+
+  /// Adds the DenseElementsAttr-specific members to the Python class:
+  /// __len__, the `get` / `get_splat` factories, the `is_splat` property and
+  /// buffer-protocol export.
+  static void bindDerived(ClassTy &c) {
+    c.def("__len__", &PyDenseElementsAttribute::dunderLen);
+    c.def_static("get", PyDenseElementsAttribute::getFromBuffer,
+                 py::arg("array"), py::arg("signless") = true,
+                 py::arg("context") = py::none(),
+                 "Gets from a buffer or ndarray");
+    c.def_static("get_splat", PyDenseElementsAttribute::getSplat,
+                 py::arg("shaped_type"), py::arg("element_attr"),
+                 "Gets a DenseElementsAttr where all values are the same");
+    c.def_property_readonly("is_splat",
+                            [](PyDenseElementsAttribute &self) -> bool {
+                              return mlirDenseElementsAttrIsSplat(self);
+                            });
+    c.def_buffer(&PyDenseElementsAttribute::accessBuffer);
+  }
+
+private:
+  /// Creates a dense elements attribute from the contents of `arrayInfo`.
+  ///
+  /// A ranked tensor type is built from the buffer's shape and
+  /// `mlirElementType`, then `ctor` is invoked with that type, the buffer's
+  /// element count and its data pointer viewed as `const ElementTy *`.
+  /// `context` is accepted but not used here.
+  template <typename ElementTy>
+  static MlirAttribute
+  bulkLoad(MlirContext context,
+           MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *),
+           MlirType mlirElementType, py::buffer_info &arrayInfo) {
+    auto dimsBegin = arrayInfo.shape.begin();
+    SmallVector<int64_t, 4> dims(dimsBegin, dimsBegin + arrayInfo.ndim);
+    auto tensorType =
+        mlirRankedTensorTypeGet(dims.size(), dims.data(), mlirElementType);
+    const intptr_t elementCount = arrayInfo.size;
+    return ctor(tensorType, elementCount,
+                static_cast<const ElementTy *>(arrayInfo.ptr));
+  }
+
+  /// Returns true when the first character of `format` is one of the
+  /// unsigned integer codes 'I', 'B', 'H', 'L' or 'Q'. An empty string is
+  /// not an integer format.
+  static bool isUnsignedIntegerFormat(const std::string &format) {
+    if (format.empty())
+      return false;
+    switch (format.front()) {
+    case 'I':
+    case 'B':
+    case 'H':
+    case 'L':
+    case 'Q':
+      return true;
+    default:
+      return false;
+    }
+  }
+
+  /// Returns true when the first character of `format` is one of the signed
+  /// integer codes 'i', 'b', 'h', 'l' or 'q'. An empty string is not an
+  /// integer format.
+  static bool isSignedIntegerFormat(const std::string &format) {
+    if (format.empty())
+      return false;
+    switch (format.front()) {
+    case 'i':
+    case 'b':
+    case 'h':
+    case 'l':
+    case 'q':
+      return true;
+    default:
+      return false;
+    }
+  }
+
+  /// Describes this attribute's raw element storage as a read-only
+  /// py::buffer_info whose items are of C++ type `Type`.
+  ///
+  /// `shapedType` supplies the rank and dimension sizes. `value` is never
+  /// called; passing the matching element accessor only lets the caller pick
+  /// `Type` through template argument deduction.
+  template <typename Type>
+  py::buffer_info bufferInfo(MlirType shapedType,
+                             Type (*value)(MlirAttribute, intptr_t)) {
+    intptr_t rank = mlirShapedTypeGetRank(shapedType);
+    // Prepare the data for the buffer_info.
+    // Buffer is configured for read-only access below.
+    Type *data = static_cast<Type *>(
+        const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
+    // Prepare the shape for the buffer_info.
+    SmallVector<intptr_t, 4> shape;
+    for (intptr_t i = 0; i < rank; ++i)
+      shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
+    // Prepare the strides for the buffer_info. There must be exactly `rank`
+    // of them (none for a rank-0 type) to match `shape`.
+    SmallVector<intptr_t, 4> strides;
+    if (mlirDenseElementsAttrIsSplat(*this)) {
+      // A splat's raw data holds a single element. Zero strides make every
+      // index resolve to that one element instead of reading past it.
+      strides.assign(rank, 0);
+    } else {
+      // Row-major byte strides, accumulated from the innermost dimension
+      // outward: stride[i] = sizeof(Type) * product of dims after i.
+      strides.resize(rank);
+      intptr_t stride = sizeof(Type);
+      for (intptr_t i = rank - 1; i >= 0; --i) {
+        strides[i] = stride;
+        stride *= mlirShapedTypeGetDimSize(shapedType, i);
+      }
+    }
+    return py::buffer_info(data, sizeof(Type),
+                           py::format_descriptor<Type>::format(), rank, shape,
+                           strides, /*readonly=*/true);
+  }
+};
+
+/// Refinement of the PyDenseElementsAttribute for attributes containing integer
+/// (and boolean) values. Supports element access.
+class PyDenseIntElementsAttribute
+    : public PyConcreteAttribute<PyDenseIntElementsAttribute,
+                                 PyDenseElementsAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
+  static constexpr const char *pyClassName = "DenseIntElementsAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  /// Returns the element at the given linear position as a Python int.
+  /// Raises IndexError when `pos` is out of range and TypeError when the
+  /// element width is not 1, 32 or 64 bits.
+  py::int_ dunderGetItem(intptr_t pos) {
+    const bool inBounds = pos >= 0 && pos < dunderLen();
+    if (!inBounds) {
+      throw SetPyError(PyExc_IndexError,
+                       "attempt to access out of bounds element");
+    }
+
+    MlirType elementType =
+        mlirShapedTypeGetElementType(mlirAttributeGetType(*this));
+    assert(mlirTypeIsAInteger(elementType) &&
+           "expected integer element type in dense int elements attribute");
+    // Pick the C accessor matching the element type's width and signedness.
+    // py::int_ is implicitly constructible from any C++ integral type and
+    // handles bitwidth correctly.
+    // TODO: consider caching the type properties in the constructor to avoid
+    // querying them on each element access.
+    const unsigned width = mlirIntegerTypeGetWidth(elementType);
+    const bool isUnsigned = mlirIntegerTypeIsUnsigned(elementType);
+    switch (width) {
+    case 1:
+      // 1-bit elements are read as booleans regardless of signedness.
+      return mlirDenseElementsAttrGetBoolValue(*this, pos);
+    case 32:
+      if (isUnsigned)
+        return mlirDenseElementsAttrGetUInt32Value(*this, pos);
+      return mlirDenseElementsAttrGetInt32Value(*this, pos);
+    case 64:
+      if (isUnsigned)
+        return mlirDenseElementsAttrGetUInt64Value(*this, pos);
+      return mlirDenseElementsAttrGetInt64Value(*this, pos);
+    default:
+      break;
+    }
+    throw SetPyError(PyExc_TypeError, "Unsupported integer type");
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
+  }
+};
+
+/// Dictionary attribute subclass - DictAttr. Supports len(), lookup by name
+/// and lookup by position.
+class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
+  static constexpr const char *pyClassName = "DictAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
+
+  static void bindDerived(ClassTy &c) {
+    c.def("__len__", &PyDictAttribute::dunderLen);
+    // DictAttr.get(value, context=None): keys are converted to str and values
+    // must be Attributes; each key becomes an identifier in the context of
+    // its value.
+    c.def_static(
+        "get",
+        [](py::dict attributes, DefaultingPyMlirContext context) {
+          SmallVector<MlirNamedAttribute> mlirNamedAttributes;
+          mlirNamedAttributes.reserve(attributes.size());
+          for (auto &it : attributes) {
+            auto &mlir_attr = it.second.cast<PyAttribute &>();
+            auto name = it.first.cast<std::string>();
+            mlirNamedAttributes.push_back(mlirNamedAttributeGet(
+                mlirIdentifierGet(mlirAttributeGetContext(mlir_attr),
+                                  toMlirStringRef(name)),
+                mlir_attr));
+          }
+          MlirAttribute attr =
+              mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
+                                    mlirNamedAttributes.data());
+          return PyDictAttribute(context->getRef(), attr);
+        },
+        py::arg("value"), py::arg("context") = py::none(),
+        "Gets an uniqued dict attribute");
+    // dict_attr["name"]: lookup by key; a null result is reported as KeyError.
+    c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
+      MlirAttribute attr =
+          mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
+      if (mlirAttributeIsNull(attr)) {
+        throw SetPyError(PyExc_KeyError,
+                         "attempt to access a non-existent attribute");
+      }
+      return PyAttribute(self.getContext(), attr);
+    });
+    // dict_attr[i]: positional lookup returning the (name, attribute) pair.
+    c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
+      if (index < 0 || index >= self.dunderLen()) {
+        throw SetPyError(PyExc_IndexError,
+                         "attempt to access out of bounds attribute");
+      }
+      MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
+      // Copy exactly `length` bytes of the name rather than scanning for a
+      // NUL terminator, so the copy never depends on what follows the data.
+      MlirStringRef nameRef = mlirIdentifierStr(namedAttr.name);
+      return PyNamedAttribute(namedAttr.attribute,
+                              std::string(nameRef.data, nameRef.length));
+    });
+  }
+};
+
+/// Refinement of PyDenseElementsAttribute for attributes containing
+/// floating-point values. Supports element access.
+class PyDenseFPElementsAttribute
+    : public PyConcreteAttribute<PyDenseFPElementsAttribute,
+                                 PyDenseElementsAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
+  static constexpr const char *pyClassName = "DenseFPElementsAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  /// Returns the element at linear position `pos` as a Python float.
+  /// Raises IndexError when `pos` is out of range and TypeError when the
+  /// element type is neither f32 nor f64.
+  py::float_ dunderGetItem(intptr_t pos) {
+    const bool inBounds = pos >= 0 && pos < dunderLen();
+    if (!inBounds) {
+      throw SetPyError(PyExc_IndexError,
+                       "attempt to access out of bounds element");
+    }
+
+    MlirType elementType =
+        mlirShapedTypeGetElementType(mlirAttributeGetType(*this));
+    // Pick the C accessor matching the element type. py::float_ is implicitly
+    // constructible from float and double.
+    // TODO: consider caching the type properties in the constructor to avoid
+    // querying them on each element access.
+    if (mlirTypeIsAF32(elementType))
+      return mlirDenseElementsAttrGetFloatValue(*this, pos);
+    if (mlirTypeIsAF64(elementType))
+      return mlirDenseElementsAttrGetDoubleValue(*this, pos);
+    throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
+  }
+};
+
+/// Attribute wrapping a Type - TypeAttr.
+class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
+  static constexpr const char *pyClassName = "TypeAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    // TypeAttr.get(value, context=None). `context` is kept for interface
+    // compatibility, but the attribute is derived solely from `value`, so the
+    // wrapper records the type's own context (as IntegerAttr.get and
+    // FloatAttr.get do) rather than a possibly unrelated one.
+    c.def_static(
+        "get",
+        [](PyType value, DefaultingPyMlirContext context) {
+          (void)context;
+          MlirAttribute attr = mlirTypeAttrGet(value.get());
+          return PyTypeAttribute(value.getContext(), attr);
+        },
+        py::arg("value"), py::arg("context") = py::none(),
+        "Gets a uniqued Type attribute");
+    // TypeAttr.value: the wrapped type, sharing this attribute's context.
+    c.def_property_readonly("value", [](PyTypeAttribute &self) {
+      return PyType(self.getContext()->getRef(),
+                    mlirTypeAttrGetValue(self.get()));
+    });
+  }
+};
+
+/// Unit attribute subclass - UnitAttr. A unit attribute carries no value.
+class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
+  static constexpr const char *pyClassName = "UnitAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    // UnitAttr.get(context=None).
+    auto getUnit = [](DefaultingPyMlirContext context) {
+      return PyUnitAttribute(context->getRef(),
+                             mlirUnitAttrGet(context->get()));
+    };
+    c.def_static("get", getUnit, py::arg("context") = py::none(),
+                 "Create a Unit attribute.");
+  }
+};
+
+} // namespace
+
+/// Registers every attribute class defined in this file, plus the ArrayAttr
+/// iterator helper, on the Python module `m`.
+void mlir::python::populateIRAttributes(py::module &m) {
+  PyAffineMapAttribute::bind(m);
+  PyArrayAttribute::bind(m);
+  PyArrayAttribute::PyArrayAttributeIterator::bind(m);
+  PyBoolAttribute::bind(m);
+  // DenseElementsAttr is the declared Python base class of the FP and Int
+  // refinements bound right after it.
+  // NOTE(review): presumably pybind11 needs the base registered first - keep
+  // this relative order.
+  PyDenseElementsAttribute::bind(m);
+  PyDenseFPElementsAttribute::bind(m);
+  PyDenseIntElementsAttribute::bind(m);
+  PyDictAttribute::bind(m);
+  PyFlatSymbolRefAttribute::bind(m);
+  PyFloatAttribute::bind(m);
+  PyIntegerAttribute::bind(m);
+  PyStringAttribute::bind(m);
+  PyTypeAttribute::bind(m);
+  PyUnitAttribute::bind(m);
+}

diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
similarity index 52%
rename from mlir/lib/Bindings/Python/IRModules.cpp
rename to mlir/lib/Bindings/Python/IRCore.cpp
index 6b4e5434d1d7..9d87aa52f7c8 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -6,16 +6,14 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "IRModules.h"
+#include "IRModule.h"
 
 #include "Globals.h"
 #include "PybindUtils.h"
 
-#include "mlir-c/AffineMap.h"
 #include "mlir-c/Bindings/Python/Interop.h"
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
-#include "mlir-c/IntegerSet.h"
 #include "mlir-c/Registration.h"
 #include "llvm/ADT/SmallVector.h"
 #include <pybind11/stl.h>
@@ -138,12 +136,6 @@ py::object classmethod(Func f, Args... args) {
   return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
 }
 
-/// Checks whether the given type is an integer or float type.
-static int mlirTypeIsAIntegerOrFloat(MlirType type) {
-  return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
-         mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
-}
-
 static py::object
 createCustomDialectWrapper(const std::string &dialectNamespace,
                            py::object dialectDescriptor) {
@@ -161,21 +153,6 @@ static MlirStringRef toMlirStringRef(const std::string &s) {
   return mlirStringRefCreate(s.data(), s.size());
 }
 
-template <typename PermutationTy>
-static bool isPermutation(std::vector<PermutationTy> permutation) {
-  llvm::SmallVector<bool, 8> seen(permutation.size(), false);
-  for (auto val : permutation) {
-    if (val < permutation.size()) {
-      if (seen[val])
-        return false;
-      seen[val] = true;
-      continue;
-    }
-    return false;
-  }
-  return true;
-}
-
 //------------------------------------------------------------------------------
 // Collections.
 //------------------------------------------------------------------------------
@@ -1466,7 +1443,8 @@ namespace {
 /// CRTP base class for Python MLIR values that subclass Value and should be
 /// castable from it. The value hierarchy is one level deep and is not supposed
 /// to accommodate other levels unless core MLIR changes.
-template <typename DerivedTy> class PyConcreteValue : public PyValue {
+template <typename DerivedTy>
+class PyConcreteValue : public PyValue {
 public:
   // Derived classes must define statics for:
   //   IsAFunctionTy isaFunction
@@ -1717,1910 +1695,169 @@ class PyOpAttributeMap {
 } // end namespace
 
 //------------------------------------------------------------------------------
-// Builtin attribute subclasses.
+// Populates the core exports of the 'ir' submodule.
 //------------------------------------------------------------------------------
 
-namespace {
-
-/// CRTP base classes for Python attributes that subclass Attribute and should
-/// be castable from it (i.e. via something like StringAttr(attr)).
-/// By default, attribute class hierarchies are one level deep (i.e. a
-/// concrete attribute class extends PyAttribute); however, intermediate
-/// python-visible base classes can be modeled by specifying a BaseTy.
-template <typename DerivedTy, typename BaseTy = PyAttribute>
-class PyConcreteAttribute : public BaseTy {
-public:
-  // Derived classes must define statics for:
-  //   IsAFunctionTy isaFunction
-  //   const char *pyClassName
-  using ClassTy = py::class_<DerivedTy, BaseTy>;
-  using IsAFunctionTy = bool (*)(MlirAttribute);
-
-  PyConcreteAttribute() = default;
-  PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
-      : BaseTy(std::move(contextRef), attr) {}
-  PyConcreteAttribute(PyAttribute &orig)
-      : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
-
-  static MlirAttribute castFrom(PyAttribute &orig) {
-    if (!DerivedTy::isaFunction(orig)) {
-      auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
-      throw SetPyError(PyExc_ValueError, Twine("Cannot cast attribute to ") +
-                                             DerivedTy::pyClassName +
-                                             " (from " + origRepr + ")");
-    }
-    return orig;
-  }
-
-  static void bind(py::module &m) {
-    auto cls = ClassTy(m, DerivedTy::pyClassName, py::buffer_protocol());
-    cls.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>());
-    DerivedTy::bindDerived(cls);
-  }
-
-  /// Implemented by derived classes to add methods to the Python subclass.
-  static void bindDerived(ClassTy &m) {}
-};
-
-class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
-  static constexpr const char *pyClassName = "AffineMapAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](PyAffineMap &affineMap) {
-          MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
-          return PyAffineMapAttribute(affineMap.getContext(), attr);
-        },
-        py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
-  }
-};
-
-class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
-  static constexpr const char *pyClassName = "ArrayAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-
-  class PyArrayAttributeIterator {
-  public:
-    PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {}
-
-    PyArrayAttributeIterator &dunderIter() { return *this; }
-
-    PyAttribute dunderNext() {
-      if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) {
-        throw py::stop_iteration();
-      }
-      return PyAttribute(attr.getContext(),
-                         mlirArrayAttrGetElement(attr.get(), nextIndex++));
-    }
-
-    static void bind(py::module &m) {
-      py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator")
-          .def("__iter__", &PyArrayAttributeIterator::dunderIter)
-          .def("__next__", &PyArrayAttributeIterator::dunderNext);
-    }
-
-  private:
-    PyAttribute attr;
-    int nextIndex = 0;
-  };
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](py::list attributes, DefaultingPyMlirContext context) {
-          SmallVector<MlirAttribute> mlirAttributes;
-          mlirAttributes.reserve(py::len(attributes));
-          for (auto attribute : attributes) {
-            try {
-              mlirAttributes.push_back(attribute.cast<PyAttribute>());
-            } catch (py::cast_error &err) {
-              std::string msg = std::string("Invalid attribute when attempting "
-                                            "to create an ArrayAttribute (") +
-                                err.what() + ")";
-              throw py::cast_error(msg);
-            } catch (py::reference_cast_error &err) {
-              // This exception seems thrown when the value is "None".
-              std::string msg =
-                  std::string("Invalid attribute (None?) when attempting to "
-                              "create an ArrayAttribute (") +
-                  err.what() + ")";
-              throw py::cast_error(msg);
+void mlir::python::populateIRCore(py::module &m) {
+  //----------------------------------------------------------------------------
+  // Mapping of MlirContext
+  //----------------------------------------------------------------------------
+  py::class_<PyMlirContext>(m, "Context")
+      .def(py::init<>(&PyMlirContext::createNewContextForInit))
+      .def_static("_get_live_count", &PyMlirContext::getLiveCount)
+      .def("_get_context_again",
+           [](PyMlirContext &self) {
+             PyMlirContextRef ref = PyMlirContext::forContext(self.get());
+             return ref.releaseObject();
+           })
+      .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
+      .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
+      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
+                             &PyMlirContext::getCapsule)
+      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
+      .def("__enter__", &PyMlirContext::contextEnter)
+      .def("__exit__", &PyMlirContext::contextExit)
+      .def_property_readonly_static(
+          "current",
+          [](py::object & /*class*/) {
+            auto *context = PyThreadContextEntry::getDefaultContext();
+            if (!context)
+              throw SetPyError(PyExc_ValueError, "No current Context");
+            return context;
+          },
+          "Gets the Context bound to the current thread or raises ValueError")
+      .def_property_readonly(
+          "dialects",
+          [](PyMlirContext &self) { return PyDialects(self.getRef()); },
+          "Gets a container for accessing dialects by name")
+      .def_property_readonly(
+          "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
+          "Alias for 'dialect'")
+      .def(
+          "get_dialect_descriptor",
+          [=](PyMlirContext &self, std::string &name) {
+            MlirDialect dialect = mlirContextGetOrLoadDialect(
+                self.get(), {name.data(), name.size()});
+            if (mlirDialectIsNull(dialect)) {
+              throw SetPyError(PyExc_ValueError,
+                               Twine("Dialect '") + name + "' not found");
             }
-          }
-          MlirAttribute attr = mlirArrayAttrGet(
-              context->get(), mlirAttributes.size(), mlirAttributes.data());
-          return PyArrayAttribute(context->getRef(), attr);
-        },
-        py::arg("attributes"), py::arg("context") = py::none(),
-        "Gets a uniqued Array attribute");
-    c.def("__getitem__",
-          [](PyArrayAttribute &arr, intptr_t i) {
-            if (i >= mlirArrayAttrGetNumElements(arr))
-              throw py::index_error("ArrayAttribute index out of range");
-            return PyAttribute(arr.getContext(),
-                               mlirArrayAttrGetElement(arr, i));
-          })
-        .def("__len__",
-             [](const PyArrayAttribute &arr) {
-               return mlirArrayAttrGetNumElements(arr);
-             })
-        .def("__iter__", [](const PyArrayAttribute &arr) {
-          return PyArrayAttributeIterator(arr);
-        });
-  }
-};
-
-/// Float Point Attribute subclass - FloatAttr.
-class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
-  static constexpr const char *pyClassName = "FloatAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](PyType &type, double value, DefaultingPyLocation loc) {
-          MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
-          // TODO: Rework error reporting once diagnostic engine is exposed
-          // in C API.
-          if (mlirAttributeIsNull(attr)) {
-            throw SetPyError(PyExc_ValueError,
-                             Twine("invalid '") +
-                                 py::repr(py::cast(type)).cast<std::string>() +
-                                 "' and expected floating point type.");
-          }
-          return PyFloatAttribute(type.getContext(), attr);
-        },
-        py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
-        "Gets an uniqued float point attribute associated to a type");
-    c.def_static(
-        "get_f32",
-        [](double value, DefaultingPyMlirContext context) {
-          MlirAttribute attr = mlirFloatAttrDoubleGet(
-              context->get(), mlirF32TypeGet(context->get()), value);
-          return PyFloatAttribute(context->getRef(), attr);
-        },
-        py::arg("value"), py::arg("context") = py::none(),
-        "Gets an uniqued float point attribute associated to a f32 type");
-    c.def_static(
-        "get_f64",
-        [](double value, DefaultingPyMlirContext context) {
-          MlirAttribute attr = mlirFloatAttrDoubleGet(
-              context->get(), mlirF64TypeGet(context->get()), value);
-          return PyFloatAttribute(context->getRef(), attr);
-        },
-        py::arg("value"), py::arg("context") = py::none(),
-        "Gets an uniqued float point attribute associated to a f64 type");
-    c.def_property_readonly(
-        "value",
-        [](PyFloatAttribute &self) {
-          return mlirFloatAttrGetValueDouble(self);
-        },
-        "Returns the value of the float point attribute");
-  }
-};
-
-/// Integer Attribute subclass - IntegerAttr.
-class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
-  static constexpr const char *pyClassName = "IntegerAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](PyType &type, int64_t value) {
-          MlirAttribute attr = mlirIntegerAttrGet(type, value);
-          return PyIntegerAttribute(type.getContext(), attr);
-        },
-        py::arg("type"), py::arg("value"),
-        "Gets an uniqued integer attribute associated to a type");
-    c.def_property_readonly(
-        "value",
-        [](PyIntegerAttribute &self) {
-          return mlirIntegerAttrGetValueInt(self);
-        },
-        "Returns the value of the integer attribute");
-  }
-};
-
-/// Bool Attribute subclass - BoolAttr.
-class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
-  static constexpr const char *pyClassName = "BoolAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](bool value, DefaultingPyMlirContext context) {
-          MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
-          return PyBoolAttribute(context->getRef(), attr);
-        },
-        py::arg("value"), py::arg("context") = py::none(),
-        "Gets an uniqued bool attribute");
-    c.def_property_readonly(
-        "value",
-        [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
-        "Returns the value of the bool attribute");
-  }
-};
-
-class PyFlatSymbolRefAttribute
-    : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
-  static constexpr const char *pyClassName = "FlatSymbolRefAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](std::string value, DefaultingPyMlirContext context) {
-          MlirAttribute attr =
-              mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
-          return PyFlatSymbolRefAttribute(context->getRef(), attr);
-        },
-        py::arg("value"), py::arg("context") = py::none(),
-        "Gets a uniqued FlatSymbolRef attribute");
-    c.def_property_readonly(
-        "value",
-        [](PyFlatSymbolRefAttribute &self) {
-          MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
-          return py::str(stringRef.data, stringRef.length);
-        },
-        "Returns the value of the FlatSymbolRef attribute as a string");
-  }
-};
-
-class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
-  static constexpr const char *pyClassName = "StringAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](std::string value, DefaultingPyMlirContext context) {
-          MlirAttribute attr =
-              mlirStringAttrGet(context->get(), toMlirStringRef(value));
-          return PyStringAttribute(context->getRef(), attr);
-        },
-        py::arg("value"), py::arg("context") = py::none(),
-        "Gets a uniqued string attribute");
-    c.def_static(
-        "get_typed",
-        [](PyType &type, std::string value) {
-          MlirAttribute attr =
-              mlirStringAttrTypedGet(type, toMlirStringRef(value));
-          return PyStringAttribute(type.getContext(), attr);
-        },
-
-        "Gets a uniqued string attribute associated to a type");
-    c.def_property_readonly(
-        "value",
-        [](PyStringAttribute &self) {
-          MlirStringRef stringRef = mlirStringAttrGetValue(self);
-          return py::str(stringRef.data, stringRef.length);
-        },
-        "Returns the value of the string attribute");
-  }
-};
-
-// TODO: Support construction of bool elements.
-// TODO: Support construction of string elements.
-class PyDenseElementsAttribute
-    : public PyConcreteAttribute<PyDenseElementsAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
-  static constexpr const char *pyClassName = "DenseElementsAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-
-  static PyDenseElementsAttribute
-  getFromBuffer(py::buffer array, bool signless,
-                DefaultingPyMlirContext contextWrapper) {
-    // Request a contiguous view. In exotic cases, this will cause a copy.
-    int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
-    Py_buffer *view = new Py_buffer();
-    if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
-      delete view;
-      throw py::error_already_set();
-    }
-    py::buffer_info arrayInfo(view);
-
-    MlirContext context = contextWrapper->get();
-    // Switch on the types that can be bulk loaded between the Python and
-    // MLIR-C APIs.
-    // See: https://docs.python.org/3/library/struct.html#format-characters
-    if (arrayInfo.format == "f") {
-      // f32
-      assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
-      return PyDenseElementsAttribute(
-          contextWrapper->getRef(),
-          bulkLoad(context, mlirDenseElementsAttrFloatGet,
-                   mlirF32TypeGet(context), arrayInfo));
-    } else if (arrayInfo.format == "d") {
-      // f64
-      assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
-      return PyDenseElementsAttribute(
-          contextWrapper->getRef(),
-          bulkLoad(context, mlirDenseElementsAttrDoubleGet,
-                   mlirF64TypeGet(context), arrayInfo));
-    } else if (isSignedIntegerFormat(arrayInfo.format)) {
-      if (arrayInfo.itemsize == 4) {
-        // i32
-        MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
-                                        : mlirIntegerTypeSignedGet(context, 32);
-        return PyDenseElementsAttribute(contextWrapper->getRef(),
-                                        bulkLoad(context,
-                                                 mlirDenseElementsAttrInt32Get,
-                                                 elementType, arrayInfo));
-      } else if (arrayInfo.itemsize == 8) {
-        // i64
-        MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
-                                        : mlirIntegerTypeSignedGet(context, 64);
-        return PyDenseElementsAttribute(contextWrapper->getRef(),
-                                        bulkLoad(context,
-                                                 mlirDenseElementsAttrInt64Get,
-                                                 elementType, arrayInfo));
-      }
-    } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
-      if (arrayInfo.itemsize == 4) {
-        // unsigned i32
-        MlirType elementType = signless
-                                   ? mlirIntegerTypeGet(context, 32)
-                                   : mlirIntegerTypeUnsignedGet(context, 32);
-        return PyDenseElementsAttribute(contextWrapper->getRef(),
-                                        bulkLoad(context,
-                                                 mlirDenseElementsAttrUInt32Get,
-                                                 elementType, arrayInfo));
-      } else if (arrayInfo.itemsize == 8) {
-        // unsigned i64
-        MlirType elementType = signless
-                                   ? mlirIntegerTypeGet(context, 64)
-                                   : mlirIntegerTypeUnsignedGet(context, 64);
-        return PyDenseElementsAttribute(contextWrapper->getRef(),
-                                        bulkLoad(context,
-                                                 mlirDenseElementsAttrUInt64Get,
-                                                 elementType, arrayInfo));
-      }
-    }
-
-    // TODO: Fall back to string-based get.
-    std::string message = "unimplemented array format conversion from format: ";
-    message.append(arrayInfo.format);
-    throw SetPyError(PyExc_ValueError, message);
-  }
-
-  static PyDenseElementsAttribute getSplat(PyType shapedType,
-                                           PyAttribute &elementAttr) {
-    auto contextWrapper =
-        PyMlirContext::forContext(mlirTypeGetContext(shapedType));
-    if (!mlirAttributeIsAInteger(elementAttr) &&
-        !mlirAttributeIsAFloat(elementAttr)) {
-      std::string message = "Illegal element type for DenseElementsAttr: ";
-      message.append(py::repr(py::cast(elementAttr)));
-      throw SetPyError(PyExc_ValueError, message);
-    }
-    if (!mlirTypeIsAShaped(shapedType) ||
-        !mlirShapedTypeHasStaticShape(shapedType)) {
-      std::string message =
-          "Expected a static ShapedType for the shaped_type parameter: ";
-      message.append(py::repr(py::cast(shapedType)));
-      throw SetPyError(PyExc_ValueError, message);
-    }
-    MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
-    MlirType attrType = mlirAttributeGetType(elementAttr);
-    if (!mlirTypeEqual(shapedElementType, attrType)) {
-      std::string message =
-          "Shaped element type and attribute type must be equal: shaped=";
-      message.append(py::repr(py::cast(shapedType)));
-      message.append(", element=");
-      message.append(py::repr(py::cast(elementAttr)));
-      throw SetPyError(PyExc_ValueError, message);
-    }
-
-    MlirAttribute elements =
-        mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
-    return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
-  }
+            return PyDialectDescriptor(self.getRef(), dialect);
+          },
+          "Gets or loads a dialect by name, returning its descriptor object")
+      .def_property(
+          "allow_unregistered_dialects",
+          [](PyMlirContext &self) -> bool {
+            return mlirContextGetAllowUnregisteredDialects(self.get());
+          },
+          [](PyMlirContext &self, bool value) {
+            mlirContextSetAllowUnregisteredDialects(self.get(), value);
+          });
 
-  intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
-
-  py::buffer_info accessBuffer() {
-    MlirType shapedType = mlirAttributeGetType(*this);
-    MlirType elementType = mlirShapedTypeGetElementType(shapedType);
-
-    if (mlirTypeIsAF32(elementType)) {
-      // f32
-      return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue);
-    } else if (mlirTypeIsAF64(elementType)) {
-      // f64
-      return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue);
-    } else if (mlirTypeIsAInteger(elementType) &&
-               mlirIntegerTypeGetWidth(elementType) == 32) {
-      if (mlirIntegerTypeIsSignless(elementType) ||
-          mlirIntegerTypeIsSigned(elementType)) {
-        // i32
-        return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value);
-      } else if (mlirIntegerTypeIsUnsigned(elementType)) {
-        // unsigned i32
-        return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value);
-      }
-    } else if (mlirTypeIsAInteger(elementType) &&
-               mlirIntegerTypeGetWidth(elementType) == 64) {
-      if (mlirIntegerTypeIsSignless(elementType) ||
-          mlirIntegerTypeIsSigned(elementType)) {
-        // i64
-        return bufferInfo(shapedType, mlirDenseElementsAttrGetInt64Value);
-      } else if (mlirIntegerTypeIsUnsigned(elementType)) {
-        // unsigned i64
-        return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value);
-      }
-    }
+  //----------------------------------------------------------------------------
+  // Mapping of PyDialectDescriptor
+  //----------------------------------------------------------------------------
+  py::class_<PyDialectDescriptor>(m, "DialectDescriptor")
+      .def_property_readonly("namespace",
+                             [](PyDialectDescriptor &self) {
+                               MlirStringRef ns =
+                                   mlirDialectGetNamespace(self.get());
+                               return py::str(ns.data, ns.length);
+                             })
+      .def("__repr__", [](PyDialectDescriptor &self) {
+        MlirStringRef ns = mlirDialectGetNamespace(self.get());
+        std::string repr("<DialectDescriptor ");
+        repr.append(ns.data, ns.length);
+        repr.append(">");
+        return repr;
+      });
 
-    std::string message = "unimplemented array format.";
-    throw SetPyError(PyExc_ValueError, message);
-  }
+  //----------------------------------------------------------------------------
+  // Mapping of PyDialects
+  //----------------------------------------------------------------------------
+  py::class_<PyDialects>(m, "Dialects")
+      .def("__getitem__",
+           [=](PyDialects &self, std::string keyName) {
+             MlirDialect dialect =
+                 self.getDialectForKey(keyName, /*attrError=*/false);
+             py::object descriptor =
+                 py::cast(PyDialectDescriptor{self.getContext(), dialect});
+             return createCustomDialectWrapper(keyName, std::move(descriptor));
+           })
+      .def("__getattr__", [=](PyDialects &self, std::string attrName) {
+        MlirDialect dialect =
+            self.getDialectForKey(attrName, /*attrError=*/true);
+        py::object descriptor =
+            py::cast(PyDialectDescriptor{self.getContext(), dialect});
+        return createCustomDialectWrapper(attrName, std::move(descriptor));
+      });
 
-  static void bindDerived(ClassTy &c) {
-    c.def("__len__", &PyDenseElementsAttribute::dunderLen)
-        .def_static("get", PyDenseElementsAttribute::getFromBuffer,
-                    py::arg("array"), py::arg("signless") = true,
-                    py::arg("context") = py::none(),
-                    "Gets from a buffer or ndarray")
-        .def_static("get_splat", PyDenseElementsAttribute::getSplat,
-                    py::arg("shaped_type"), py::arg("element_attr"),
-                    "Gets a DenseElementsAttr where all values are the same")
-        .def_property_readonly("is_splat",
-                               [](PyDenseElementsAttribute &self) -> bool {
-                                 return mlirDenseElementsAttrIsSplat(self);
-                               })
-        .def_buffer(&PyDenseElementsAttribute::accessBuffer);
-  }
+  //----------------------------------------------------------------------------
+  // Mapping of PyDialect
+  //----------------------------------------------------------------------------
+  py::class_<PyDialect>(m, "Dialect")
+      .def(py::init<py::object>(), "descriptor")
+      .def_property_readonly(
+          "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
+      .def("__repr__", [](py::object self) {
+        auto clazz = self.attr("__class__");
+        return py::str("<Dialect ") +
+               self.attr("descriptor").attr("namespace") + py::str(" (class ") +
+               clazz.attr("__module__") + py::str(".") +
+               clazz.attr("__name__") + py::str(")>");
+      });
 
-private:
-  template <typename ElementTy>
-  static MlirAttribute
-  bulkLoad(MlirContext context,
-           MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *),
-           MlirType mlirElementType, py::buffer_info &arrayInfo) {
-    SmallVector<int64_t, 4> shape(arrayInfo.shape.begin(),
-                                  arrayInfo.shape.begin() + arrayInfo.ndim);
-    auto shapedType =
-        mlirRankedTensorTypeGet(shape.size(), shape.data(), mlirElementType);
-    intptr_t numElements = arrayInfo.size;
-    const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr);
-    return ctor(shapedType, numElements, contents);
-  }
-
-  static bool isUnsignedIntegerFormat(const std::string &format) {
-    if (format.empty())
-      return false;
-    char code = format[0];
-    return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
-           code == 'Q';
-  }
-
-  static bool isSignedIntegerFormat(const std::string &format) {
-    if (format.empty())
-      return false;
-    char code = format[0];
-    return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
-           code == 'q';
-  }
-
-  template <typename Type>
-  py::buffer_info bufferInfo(MlirType shapedType,
-                             Type (*value)(MlirAttribute, intptr_t)) {
-    intptr_t rank = mlirShapedTypeGetRank(shapedType);
-    // Prepare the data for the buffer_info.
-    // Buffer is configured for read-only access below.
-    Type *data = static_cast<Type *>(
-        const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
-    // Prepare the shape for the buffer_info.
-    SmallVector<intptr_t, 4> shape;
-    for (intptr_t i = 0; i < rank; ++i)
-      shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
-    // Prepare the strides for the buffer_info.
-    SmallVector<intptr_t, 4> strides;
-    intptr_t strideFactor = 1;
-    for (intptr_t i = 1; i < rank; ++i) {
-      strideFactor = 1;
-      for (intptr_t j = i; j < rank; ++j) {
-        strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
-      }
-      strides.push_back(sizeof(Type) * strideFactor);
-    }
-    strides.push_back(sizeof(Type));
-    return py::buffer_info(data, sizeof(Type),
-                           py::format_descriptor<Type>::format(), rank, shape,
-                           strides, /*readonly=*/true);
-  }
-}; // namespace
-
-/// Refinement of the PyDenseElementsAttribute for attributes containing integer
-/// (and boolean) values. Supports element access.
-class PyDenseIntElementsAttribute
-    : public PyConcreteAttribute<PyDenseIntElementsAttribute,
-                                 PyDenseElementsAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
-  static constexpr const char *pyClassName = "DenseIntElementsAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-
-  /// Returns the element at the given linear position. Asserts if the index is
-  /// out of range.
-  py::int_ dunderGetItem(intptr_t pos) {
-    if (pos < 0 || pos >= dunderLen()) {
-      throw SetPyError(PyExc_IndexError,
-                       "attempt to access out of bounds element");
-    }
-
-    MlirType type = mlirAttributeGetType(*this);
-    type = mlirShapedTypeGetElementType(type);
-    assert(mlirTypeIsAInteger(type) &&
-           "expected integer element type in dense int elements attribute");
-    // Dispatch element extraction to an appropriate C function based on the
-    // elemental type of the attribute. py::int_ is implicitly constructible
-    // from any C++ integral type and handles bitwidth correctly.
-    // TODO: consider caching the type properties in the constructor to avoid
-    // querying them on each element access.
-    unsigned width = mlirIntegerTypeGetWidth(type);
-    bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
-    if (isUnsigned) {
-      if (width == 1) {
-        return mlirDenseElementsAttrGetBoolValue(*this, pos);
-      }
-      if (width == 32) {
-        return mlirDenseElementsAttrGetUInt32Value(*this, pos);
-      }
-      if (width == 64) {
-        return mlirDenseElementsAttrGetUInt64Value(*this, pos);
-      }
-    } else {
-      if (width == 1) {
-        return mlirDenseElementsAttrGetBoolValue(*this, pos);
-      }
-      if (width == 32) {
-        return mlirDenseElementsAttrGetInt32Value(*this, pos);
-      }
-      if (width == 64) {
-        return mlirDenseElementsAttrGetInt64Value(*this, pos);
-      }
-    }
-    throw SetPyError(PyExc_TypeError, "Unsupported integer type");
-  }
-
-  static void bindDerived(ClassTy &c) {
-    c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
-  }
-};
-
-/// Python binding for the builtin dictionary attribute (DictAttr). Exposes
-/// `__len__`, a static `get` constructor taking a Python dict, and
-/// `__getitem__` overloads keyed either by name or by position.
-class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
-  static constexpr const char *pyClassName = "DictAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-
-  /// Returns the number of named attributes stored in the dictionary.
-  intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
-
-  static void bindDerived(ClassTy &c) {
-    c.def("__len__", &PyDictAttribute::dunderLen);
-    c.def_static(
-        "get",
-        [](py::dict attributes, DefaultingPyMlirContext context) {
-          // Convert every (name, attribute) pair of the Python dict into an
-          // MlirNamedAttribute. The identifier is created in the context of
-          // the attribute it names.
-          SmallVector<MlirNamedAttribute> mlirNamedAttributes;
-          mlirNamedAttributes.reserve(attributes.size());
-          for (auto &it : attributes) {
-            auto &mlir_attr = it.second.cast<PyAttribute &>();
-            auto name = it.first.cast<std::string>();
-            mlirNamedAttributes.push_back(mlirNamedAttributeGet(
-                mlirIdentifierGet(mlirAttributeGetContext(mlir_attr),
-                                  toMlirStringRef(name)),
-                mlir_attr));
-          }
-          MlirAttribute attr =
-              mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
-                                    mlirNamedAttributes.data());
-          return PyDictAttribute(context->getRef(), attr);
-        },
-        py::arg("value"), py::arg("context") = py::none(),
-        "Gets an uniqued dict attribute");
-    // Lookup by name: raises KeyError when no attribute has that name.
-    c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
-      MlirAttribute attr =
-          mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
-      if (mlirAttributeIsNull(attr)) {
-        throw SetPyError(PyExc_KeyError,
-                         "attempt to access a non-existent attribute");
-      }
-      return PyAttribute(self.getContext(), attr);
-    });
-    // Lookup by position: raises IndexError outside [0, len).
-    c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
-      if (index < 0 || index >= self.dunderLen()) {
-        throw SetPyError(PyExc_IndexError,
-                         "attempt to access out of bounds attribute");
-      }
-      MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
-      // An MlirStringRef carries an explicit length and is not guaranteed to
-      // be NUL-terminated, so copy exactly `length` bytes rather than relying
-      // on a terminator.
-      MlirStringRef nameRef = mlirIdentifierStr(namedAttr.name);
-      return PyNamedAttribute(namedAttr.attribute,
-                              std::string(nameRef.data, nameRef.length));
-    });
-  }
-};
-
-/// Specialization of PyDenseElementsAttribute for dense attributes holding
-/// floating-point values. Adds per-element access through `__getitem__`.
-class PyDenseFPElementsAttribute
-    : public PyConcreteAttribute<PyDenseFPElementsAttribute,
-                                 PyDenseElementsAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
-  static constexpr const char *pyClassName = "DenseFPElementsAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-
-  /// Returns the element at linear position `pos` as a Python float.
-  /// Raises IndexError when `pos` lies outside [0, len) and TypeError when
-  /// the element type is neither f32 nor f64.
-  py::float_ dunderGetItem(intptr_t pos) {
-    const bool inBounds = pos >= 0 && pos < dunderLen();
-    if (!inBounds) {
-      throw SetPyError(PyExc_IndexError,
-                       "attempt to access out of bounds element");
-    }
-
-    // Pick the C accessor matching the element type of the attribute;
-    // py::float_ converts implicitly from both float and double.
-    // TODO: consider caching the type properties in the constructor to avoid
-    // querying them on each element access.
-    MlirType elementType =
-        mlirShapedTypeGetElementType(mlirAttributeGetType(*this));
-    if (mlirTypeIsAF64(elementType)) {
-      return mlirDenseElementsAttrGetDoubleValue(*this, pos);
-    }
-    if (mlirTypeIsAF32(elementType)) {
-      return mlirDenseElementsAttrGetFloatValue(*this, pos);
-    }
-    throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
-  }
-
-  static void bindDerived(ClassTy &c) {
-    c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
-  }
-};
-
-/// Python binding for the builtin type attribute (TypeAttr), an attribute
-/// wrapping a Type.
-class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
-  static constexpr const char *pyClassName = "TypeAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-
-  static void bindDerived(ClassTy &c) {
-    // Static constructor wrapping `value` in a type attribute.
-    c.def_static(
-        "get",
-        [](PyType value, DefaultingPyMlirContext context) {
-          return PyTypeAttribute(context->getRef(),
-                                 mlirTypeAttrGet(value.get()));
-        },
-        py::arg("value"), py::arg("context") = py::none(),
-        "Gets a uniqued Type attribute");
-    // Read-only accessor returning the wrapped type.
-    c.def_property_readonly("value", [](PyTypeAttribute &self) {
-      MlirType wrappedType = mlirTypeAttrGetValue(self.get());
-      return PyType(self.getContext()->getRef(), wrappedType);
-    });
-  }
-};
-
-/// Python binding for the unit attribute (UnitAttr). A unit attribute carries
-/// no value.
-class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
-  static constexpr const char *pyClassName = "UnitAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-
-  static void bindDerived(ClassTy &c) {
-    // Static constructor; `context` defaults to None on the Python side.
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirAttribute unitAttr = mlirUnitAttrGet(context->get());
-          return PyUnitAttribute(context->getRef(), unitAttr);
-        },
-        py::arg("context") = py::none(), "Create a Unit attribute.");
-  }
-};
-
-} // namespace
-
-//------------------------------------------------------------------------------
-// Builtin type subclasses.
-//------------------------------------------------------------------------------
-
-namespace {
-
-/// CRTP helper for Python-visible classes that derive from Type and can be
-/// constructed from a generic Type (e.g. IntegerType(t)).
-/// Hierarchies are one level deep by default (the concrete class extends
-/// PyType); pass a BaseTy to model an intermediate Python-visible base class.
-template <typename DerivedTy, typename BaseTy = PyType>
-class PyConcreteType : public BaseTy {
-public:
-  // Every DerivedTy is expected to provide these static members:
-  //   IsAFunctionTy isaFunction
-  //   const char *pyClassName
-  using ClassTy = py::class_<DerivedTy, BaseTy>;
-  using IsAFunctionTy = bool (*)(MlirType);
-
-  PyConcreteType() = default;
-  PyConcreteType(PyMlirContextRef contextRef, MlirType t)
-      : BaseTy(std::move(contextRef), t) {}
-  PyConcreteType(PyType &orig)
-      : PyConcreteType(orig.getContext(), castFrom(orig)) {}
-
-  /// Returns the MlirType held by `orig` when DerivedTy::isaFunction accepts
-  /// it; otherwise raises a Python ValueError naming the target class and the
-  /// repr of `orig`.
-  static MlirType castFrom(PyType &orig) {
-    if (DerivedTy::isaFunction(orig))
-      return orig;
-
-    std::string reprText = py::repr(py::cast(orig)).cast<std::string>();
-    throw SetPyError(PyExc_ValueError, Twine("Cannot cast type to ") +
-                                           DerivedTy::pyClassName +
-                                           " (from " + reprText + ")");
-  }
-
-  /// Registers the Python class in module `m`: a casting constructor, a
-  /// static `isinstance` predicate, and whatever DerivedTy::bindDerived adds.
-  static void bind(py::module &m) {
-    ClassTy cls(m, DerivedTy::pyClassName);
-    cls.def(py::init<PyType &>(), py::keep_alive<0, 1>());
-    cls.def_static("isinstance", [](PyType &candidate) -> bool {
-      return DerivedTy::isaFunction(candidate);
-    });
-    DerivedTy::bindDerived(cls);
-  }
-
-  /// Hook for derived classes to add methods to the Python class; the default
-  /// adds nothing.
-  static void bindDerived(ClassTy &m) {}
-};
-
-/// Python binding for the builtin integer type (IntegerType). Provides
-/// constructors for signless, signed and unsigned integers plus read-only
-/// properties for the width and signedness.
-class PyIntegerType : public PyConcreteType<PyIntegerType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
-  static constexpr const char *pyClassName = "IntegerType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    // Static constructors, one per signedness flavour.
-    c.def_static(
-        "get_signless",
-        [](unsigned width, DefaultingPyMlirContext context) {
-          return PyIntegerType(context->getRef(),
-                               mlirIntegerTypeGet(context->get(), width));
-        },
-        py::arg("width"), py::arg("context") = py::none(),
-        "Create a signless integer type");
-    c.def_static(
-        "get_signed",
-        [](unsigned width, DefaultingPyMlirContext context) {
-          return PyIntegerType(
-              context->getRef(),
-              mlirIntegerTypeSignedGet(context->get(), width));
-        },
-        py::arg("width"), py::arg("context") = py::none(),
-        "Create a signed integer type");
-    c.def_static(
-        "get_unsigned",
-        [](unsigned width, DefaultingPyMlirContext context) {
-          return PyIntegerType(
-              context->getRef(),
-              mlirIntegerTypeUnsignedGet(context->get(), width));
-        },
-        py::arg("width"), py::arg("context") = py::none(),
-        "Create an unsigned integer type");
-
-    // Read-only introspection properties.
-    c.def_property_readonly(
-        "width",
-        [](PyIntegerType &self) {
-          auto bitWidth = mlirIntegerTypeGetWidth(self);
-          return bitWidth;
-        },
-        "Returns the width of the integer type");
-    c.def_property_readonly(
-        "is_signless",
-        [](PyIntegerType &self) -> bool {
-          return mlirIntegerTypeIsSignless(self);
-        },
-        "Returns whether this is a signless integer");
-    c.def_property_readonly(
-        "is_signed",
-        [](PyIntegerType &self) -> bool {
-          return mlirIntegerTypeIsSigned(self);
-        },
-        "Returns whether this is a signed integer");
-    c.def_property_readonly(
-        "is_unsigned",
-        [](PyIntegerType &self) -> bool {
-          return mlirIntegerTypeIsUnsigned(self);
-        },
-        "Returns whether this is an unsigned integer");
-  }
-};
-
-/// Python binding for the builtin index type (IndexType).
-class PyIndexType : public PyConcreteType<PyIndexType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
-  static constexpr const char *pyClassName = "IndexType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    // Static constructor; `context` defaults to None on the Python side.
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext ctx) {
-          return PyIndexType(ctx->getRef(), mlirIndexTypeGet(ctx->get()));
-        },
-        py::arg("context") = py::none(), "Create a index type.");
-  }
-};
-
-/// Python binding for the builtin bf16 floating-point type (BF16Type).
-class PyBF16Type : public PyConcreteType<PyBF16Type> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
-  static constexpr const char *pyClassName = "BF16Type";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    // Static constructor; `context` defaults to None on the Python side.
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext ctx) {
-          return PyBF16Type(ctx->getRef(), mlirBF16TypeGet(ctx->get()));
-        },
-        py::arg("context") = py::none(), "Create a bf16 type.");
-  }
-};
-
-/// Python binding for the builtin f16 floating-point type (F16Type).
-class PyF16Type : public PyConcreteType<PyF16Type> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
-  static constexpr const char *pyClassName = "F16Type";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    // Static constructor; `context` defaults to None on the Python side.
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext ctx) {
-          return PyF16Type(ctx->getRef(), mlirF16TypeGet(ctx->get()));
-        },
-        py::arg("context") = py::none(), "Create a f16 type.");
-  }
-};
-
-/// Python binding for the builtin f32 floating-point type (F32Type).
-class PyF32Type : public PyConcreteType<PyF32Type> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
-  static constexpr const char *pyClassName = "F32Type";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    // Static constructor; `context` defaults to None on the Python side.
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext ctx) {
-          return PyF32Type(ctx->getRef(), mlirF32TypeGet(ctx->get()));
-        },
-        py::arg("context") = py::none(), "Create a f32 type.");
-  }
-};
-
-/// Python binding for the builtin f64 floating-point type (F64Type).
-class PyF64Type : public PyConcreteType<PyF64Type> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
-  static constexpr const char *pyClassName = "F64Type";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    // Static constructor; `context` defaults to None on the Python side.
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext ctx) {
-          return PyF64Type(ctx->getRef(), mlirF64TypeGet(ctx->get()));
-        },
-        py::arg("context") = py::none(), "Create a f64 type.");
-  }
-};
-
-/// Python binding for the builtin none type (NoneType).
-class PyNoneType : public PyConcreteType<PyNoneType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
-  static constexpr const char *pyClassName = "NoneType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    // Static constructor; `context` defaults to None on the Python side.
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext ctx) {
-          return PyNoneType(ctx->getRef(), mlirNoneTypeGet(ctx->get()));
-        },
-        py::arg("context") = py::none(), "Create a none type.");
-  }
-};
-
-/// Python binding for the builtin complex type (ComplexType).
-class PyComplexType : public PyConcreteType<PyComplexType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
-  static constexpr const char *pyClassName = "ComplexType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](PyType &elementType) {
-          // Only integer or floating point scalars are accepted as the
-          // element type; anything else raises ValueError.
-          if (!mlirTypeIsAIntegerOrFloat(elementType)) {
-            throw SetPyError(
-                PyExc_ValueError,
-                Twine("invalid '") +
-                    py::repr(py::cast(elementType)).cast<std::string>() +
-                    "' and expected floating point or integer type.");
-          }
-          return PyComplexType(elementType.getContext(),
-                               mlirComplexTypeGet(elementType));
-        },
-        "Create a complex type");
-    c.def_property_readonly(
-        "element_type",
-        [](PyComplexType &self) -> PyType {
-          return PyType(self.getContext(),
-                        mlirComplexTypeGetElementType(self));
-        },
-        "Returns element type.");
-  }
-};
-
-/// Python binding for ShapedType, the Python-visible base class of the
-/// vector, tensor and memref type bindings. Exposes element type, rank and
-/// per-dimension queries.
-class PyShapedType : public PyConcreteType<PyShapedType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAShaped;
-  static constexpr const char *pyClassName = "ShapedType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_property_readonly(
-        "element_type",
-        [](PyShapedType &self) {
-          MlirType t = mlirShapedTypeGetElementType(self);
-          return PyType(self.getContext(), t);
-        },
-        "Returns the element type of the shaped type.");
-    c.def_property_readonly(
-        "has_rank",
-        [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); },
-        "Returns whether the given shaped type is ranked.");
-    c.def_property_readonly(
-        "rank",
-        [](PyShapedType &self) {
-          self.requireHasRank();
-          return mlirShapedTypeGetRank(self);
-        },
-        "Returns the rank of the given ranked shaped type.");
-    c.def_property_readonly(
-        "has_static_shape",
-        [](PyShapedType &self) -> bool {
-          return mlirShapedTypeHasStaticShape(self);
-        },
-        "Returns whether the given shaped type has a static shape.");
-    c.def(
-        "is_dynamic_dim",
-        [](PyShapedType &self, intptr_t dim) -> bool {
-          // Validate rank and index before touching the C API, which does
-          // not report out-of-range dimensions as Python errors.
-          self.requireValidDim(dim);
-          return mlirShapedTypeIsDynamicDim(self, dim);
-        },
-        "Returns whether the dim-th dimension of the given shaped type is "
-        "dynamic.");
-    c.def(
-        "get_dim_size",
-        [](PyShapedType &self, intptr_t dim) {
-          self.requireValidDim(dim);
-          return mlirShapedTypeGetDimSize(self, dim);
-        },
-        "Returns the dim-th dimension of the given ranked shaped type.");
-    c.def_static(
-        "is_dynamic_size",
-        [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); },
-        "Returns whether the given dimension size indicates a dynamic "
-        "dimension.");
-    c.def(
-        "is_dynamic_stride_or_offset",
-        [](PyShapedType &self, int64_t val) -> bool {
-          self.requireHasRank();
-          return mlirShapedTypeIsDynamicStrideOrOffset(val);
-        },
-        "Returns whether the given value is used as a placeholder for dynamic "
-        "strides and offsets in shaped types.");
-  }
-
-private:
-  /// Raises a Python ValueError unless the type is ranked.
-  void requireHasRank() {
-    if (!mlirShapedTypeHasRank(*this)) {
-      throw SetPyError(
-          PyExc_ValueError,
-          "calling this method requires that the type has a rank.");
-    }
-  }
-
-  /// Raises a Python ValueError unless the type is ranked, and a Python
-  /// IndexError unless `dim` lies in [0, rank).
-  void requireValidDim(intptr_t dim) {
-    requireHasRank();
-    if (dim < 0 || dim >= mlirShapedTypeGetRank(*this)) {
-      throw SetPyError(PyExc_IndexError,
-                       "attempt to access out of bounds dimension");
-    }
-  }
-};
-
-/// Python binding for the builtin vector type (VectorType), a ShapedType.
-class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
-  static constexpr const char *pyClassName = "VectorType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](std::vector<int64_t> shape, PyType &elementType,
-           DefaultingPyLocation loc) {
-          MlirType vectorTy = mlirVectorTypeGetChecked(
-              loc, shape.size(), shape.data(), elementType);
-          // A null result signals that the checked construction failed.
-          // TODO: Rework error reporting once diagnostic engine is exposed
-          // in C API.
-          if (!mlirTypeIsNull(vectorTy))
-            return PyVectorType(elementType.getContext(), vectorTy);
-          throw SetPyError(
-              PyExc_ValueError,
-              Twine("invalid '") +
-                  py::repr(py::cast(elementType)).cast<std::string>() +
-                  "' and expected floating point or integer type.");
-        },
-        py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(),
-        "Create a vector type");
-  }
-};
-
-/// Python binding for the builtin ranked tensor type (RankedTensorType), a
-/// ShapedType.
-class PyRankedTensorType
-    : public PyConcreteType<PyRankedTensorType, PyShapedType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
-  static constexpr const char *pyClassName = "RankedTensorType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](std::vector<int64_t> shape, PyType &elementType,
-           DefaultingPyLocation loc) {
-          MlirType tensorTy = mlirRankedTensorTypeGetChecked(
-              loc, shape.size(), shape.data(), elementType);
-          // A null result signals that the checked construction failed.
-          // TODO: Rework error reporting once diagnostic engine is exposed
-          // in C API.
-          if (!mlirTypeIsNull(tensorTy))
-            return PyRankedTensorType(elementType.getContext(), tensorTy);
-          throw SetPyError(
-              PyExc_ValueError,
-              Twine("invalid '") +
-                  py::repr(py::cast(elementType)).cast<std::string>() +
-                  "' and expected floating point, integer, vector or "
-                  "complex "
-                  "type.");
-        },
-        py::arg("shape"), py::arg("element_type"), py::arg("loc") = py::none(),
-        "Create a ranked tensor type");
-  }
-};
-
-/// Python binding for the builtin unranked tensor type (UnrankedTensorType),
-/// a ShapedType.
-class PyUnrankedTensorType
-    : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
-  static constexpr const char *pyClassName = "UnrankedTensorType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](PyType &elementType, DefaultingPyLocation loc) {
-          MlirType tensorTy =
-              mlirUnrankedTensorTypeGetChecked(loc, elementType);
-          // A null result signals that the checked construction failed.
-          // TODO: Rework error reporting once diagnostic engine is exposed
-          // in C API.
-          if (!mlirTypeIsNull(tensorTy))
-            return PyUnrankedTensorType(elementType.getContext(), tensorTy);
-          throw SetPyError(
-              PyExc_ValueError,
-              Twine("invalid '") +
-                  py::repr(py::cast(elementType)).cast<std::string>() +
-                  "' and expected floating point, integer, vector or "
-                  "complex "
-                  "type.");
-        },
-        py::arg("element_type"), py::arg("loc") = py::none(),
-        "Create a unranked tensor type");
-  }
-};
-
-class PyMemRefLayoutMapList;
-
-/// Python binding for the builtin ranked memref type (MemRefType), a
-/// ShapedType. Exposes a checked `get` constructor and read-only `layout` and
-/// `memory_space` properties.
-class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
-public:
-  // The predicate must test for a memref: it drives both the casting
-  // constructor MemRefType(t) and MemRefType.isinstance(t).
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
-  static constexpr const char *pyClassName = "MemRefType";
-  using PyConcreteType::PyConcreteType;
-
-  /// Returns a list-like view over the affine layout maps of this memref.
-  PyMemRefLayoutMapList getLayout();
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-         "get",
-         [](std::vector<int64_t> shape, PyType &elementType,
-            std::vector<PyAffineMap> layout, PyAttribute *memorySpace,
-            DefaultingPyLocation loc) {
-           // Lower the Python-side affine maps to their C API handles.
-           SmallVector<MlirAffineMap> maps;
-           maps.reserve(layout.size());
-           for (PyAffineMap &map : layout)
-             maps.push_back(map);
-
-           // A missing memory space is passed as a value-initialized
-           // attribute handle.
-           MlirAttribute memSpaceAttr = {};
-           if (memorySpace)
-             memSpaceAttr = *memorySpace;
-
-           MlirType t = mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
-                                                 shape.data(), maps.size(),
-                                                 maps.data(), memSpaceAttr);
-           // TODO: Rework error reporting once diagnostic engine is exposed
-           // in C API.
-           if (mlirTypeIsNull(t)) {
-             throw SetPyError(
-                 PyExc_ValueError,
-                 Twine("invalid '") +
-                     py::repr(py::cast(elementType)).cast<std::string>() +
-                     "' and expected floating point, integer, vector or "
-                     "complex "
-                     "type.");
-           }
-           return PyMemRefType(elementType.getContext(), t);
-         },
-         py::arg("shape"), py::arg("element_type"),
-         py::arg("layout") = py::list(), py::arg("memory_space") = py::none(),
-         py::arg("loc") = py::none(), "Create a memref type")
-        .def_property_readonly("layout", &PyMemRefType::getLayout,
-                               "The list of layout maps of the MemRef type.")
-        .def_property_readonly(
-            "memory_space",
-            [](PyMemRefType &self) -> PyAttribute {
-              MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
-              return PyAttribute(self.getContext(), a);
-            },
-            "Returns the memory space of the given MemRef type.");
-  }
-};
-
-/// Sliceable view over the affine layout maps of a memref type. The maps are
-/// stored as consecutive elements, so random access is cheap. Both the type
-/// and the maps are owned by the context, so no lifetime extension is needed.
-class PyMemRefLayoutMapList
-    : public Sliceable<PyMemRefLayoutMapList, PyAffineMap> {
-public:
-  static constexpr const char *pyClassName = "MemRefLayoutMapList";
-
-  /// Builds a view over `type`'s layout maps. A `length` of -1 means "all
-  /// affine maps of the type".
-  PyMemRefLayoutMapList(PyMemRefType type, intptr_t startIndex = 0,
-                        intptr_t length = -1, intptr_t step = 1)
-      : Sliceable(startIndex,
-                  length != -1 ? length : mlirMemRefTypeGetNumAffineMaps(type),
-                  step),
-        memref(type) {}
-
-  /// Total number of affine maps on the underlying memref type.
-  intptr_t getNumElements() { return mlirMemRefTypeGetNumAffineMaps(memref); }
-
-  /// Wraps the `index`-th affine map of the underlying memref type.
-  PyAffineMap getElement(intptr_t index) {
-    MlirAffineMap rawMap = mlirMemRefTypeGetAffineMap(memref, index);
-    return PyAffineMap(memref.getContext(), rawMap);
-  }
-
-  /// Creates a new view over the same memref with the given slice parameters.
-  PyMemRefLayoutMapList slice(intptr_t startIndex, intptr_t length,
-                              intptr_t step) {
-    PyMemRefLayoutMapList sliced(memref, startIndex, length, step);
-    return sliced;
-  }
-
-private:
-  PyMemRefType memref;
-};
-
-/// Returns a view over all layout maps of this memref type.
-PyMemRefLayoutMapList PyMemRefType::getLayout() {
-  PyMemRefLayoutMapList layoutMaps(*this);
-  return layoutMaps;
-}
-
-/// Python binding for the builtin unranked memref type (UnrankedMemRefType),
-/// a ShapedType. Exposes a checked `get` constructor and a read-only
-/// `memory_space` property.
-class PyUnrankedMemRefType
-    : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
-  static constexpr const char *pyClassName = "UnrankedMemRefType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-         "get",
-         [](PyType &elementType, PyAttribute *memorySpace,
-            DefaultingPyLocation loc) {
-           // A missing memory space is passed as a value-initialized
-           // attribute handle.
-           MlirAttribute memSpaceAttr = {};
-           if (memorySpace)
-             memSpaceAttr = *memorySpace;
-
-           MlirType t =
-               mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr);
-           // TODO: Rework error reporting once diagnostic engine is exposed
-           // in C API.
-           if (mlirTypeIsNull(t)) {
-             throw SetPyError(
-                 PyExc_ValueError,
-                 Twine("invalid '") +
-                     py::repr(py::cast(elementType)).cast<std::string>() +
-                     "' and expected floating point, integer, vector or "
-                     "complex "
-                     "type.");
-           }
-           return PyUnrankedMemRefType(elementType.getContext(), t);
-         },
-         py::arg("element_type"), py::arg("memory_space"),
-         py::arg("loc") = py::none(), "Create a unranked memref type")
-        .def_property_readonly(
-            "memory_space",
-            [](PyUnrankedMemRefType &self) -> PyAttribute {
-              // Use the unranked accessor: the ranked-memref accessor is not
-              // valid for an unranked memref type.
-              MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
-              return PyAttribute(self.getContext(), a);
-            },
-            "Returns the memory space of the given Unranked MemRef type.");
-  }
-};
-
-/// Python binding for the builtin tuple type (TupleType). Exposes a
-/// `get_tuple` constructor, positional element access and the element count.
-class PyTupleType : public PyConcreteType<PyTupleType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
-  static constexpr const char *pyClassName = "TupleType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get_tuple",
-        [](py::list elementList, DefaultingPyMlirContext context) {
-          // Mapping py::list to SmallVector. The count handed to the C API
-          // is taken from the vector itself so the two cannot disagree.
-          SmallVector<MlirType, 4> elements;
-          elements.reserve(py::len(elementList));
-          for (auto element : elementList)
-            elements.push_back(element.cast<PyType>());
-          MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
-                                        elements.data());
-          return PyTupleType(context->getRef(), t);
-        },
-        py::arg("elements"), py::arg("context") = py::none(),
-        "Create a tuple type");
-    c.def(
-        "get_type",
-        [](PyTupleType &self, intptr_t pos) -> PyType {
-          // Reject out-of-range positions with IndexError instead of handing
-          // them to the C API unchecked.
-          if (pos < 0 || pos >= mlirTupleTypeGetNumTypes(self)) {
-            throw SetPyError(PyExc_IndexError,
-                             "attempt to access out of bounds tuple element");
-          }
-          MlirType t = mlirTupleTypeGetType(self, pos);
-          return PyType(self.getContext(), t);
-        },
-        "Returns the pos-th type in the tuple type.");
-    c.def_property_readonly(
-        "num_types",
-        [](PyTupleType &self) -> intptr_t {
-          return mlirTupleTypeGetNumTypes(self);
-        },
-        "Returns the number of types contained in a tuple.");
-  }
-};
-
-/// Python binding for the builtin function type (FunctionType). Exposes a
-/// `get` constructor and read-only `inputs` / `results` lists.
-class PyFunctionType : public PyConcreteType<PyFunctionType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
-  static constexpr const char *pyClassName = "FunctionType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](std::vector<PyType> inputs, std::vector<PyType> results,
-           DefaultingPyMlirContext context) {
-          // Lower both Python-side type lists to raw C API handles.
-          SmallVector<MlirType, 4> rawInputs(inputs.begin(), inputs.end());
-          SmallVector<MlirType, 4> rawResults(results.begin(), results.end());
-          MlirType fnType = mlirFunctionTypeGet(
-              context->get(), rawInputs.size(), rawInputs.data(),
-              rawResults.size(), rawResults.data());
-          return PyFunctionType(context->getRef(), fnType);
-        },
-        py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(),
-        "Gets a FunctionType from a list of input and result types");
-    c.def_property_readonly(
-        "inputs",
-        [](PyFunctionType &self) {
-          auto ctxRef = self.getContext();
-          MlirType fnType = self;
-          const intptr_t numInputs = mlirFunctionTypeGetNumInputs(self);
-          py::list inputTypes;
-          for (intptr_t idx = 0; idx < numInputs; ++idx) {
-            MlirType input = mlirFunctionTypeGetInput(fnType, idx);
-            inputTypes.append(PyType(ctxRef, input));
-          }
-          return inputTypes;
-        },
-        "Returns the list of input types in the FunctionType.");
-    c.def_property_readonly(
-        "results",
-        [](PyFunctionType &self) {
-          auto ctxRef = self.getContext();
-          const intptr_t numResults = mlirFunctionTypeGetNumResults(self);
-          py::list resultTypes;
-          for (intptr_t idx = 0; idx < numResults; ++idx) {
-            MlirType result = mlirFunctionTypeGetResult(self, idx);
-            resultTypes.append(PyType(ctxRef, result));
-          }
-          return resultTypes;
-        },
-        "Returns the list of result types in the FunctionType.");
-  }
-};
-
-} // namespace
-
-//------------------------------------------------------------------------------
-// PyAffineExpr and subclasses.
-//------------------------------------------------------------------------------
-
-namespace {
-/// CRTP helper for Python-visible affine expression classes that derive from
-/// AffineExpr and can be constructed from a generic one. Pass a BaseTy to
-/// model an intermediate class in the hierarchy.
-template <typename DerivedTy, typename BaseTy = PyAffineExpr>
-class PyConcreteAffineExpr : public BaseTy {
-public:
-  // Every DerivedTy is expected to provide these static members:
-  //   IsAFunctionTy isaFunction
-  //   const char *pyClassName
-  // and to redefine bindDerived.
-  using ClassTy = py::class_<DerivedTy, BaseTy>;
-  using IsAFunctionTy = bool (*)(MlirAffineExpr);
-
-  PyConcreteAffineExpr() = default;
-  PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
-      : BaseTy(std::move(contextRef), affineExpr) {}
-  PyConcreteAffineExpr(PyAffineExpr &orig)
-      : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {}
-
-  /// Returns the MlirAffineExpr held by `orig` when DerivedTy::isaFunction
-  /// accepts it; otherwise raises a Python ValueError naming the target class
-  /// and the repr of `orig`.
-  static MlirAffineExpr castFrom(PyAffineExpr &orig) {
-    if (DerivedTy::isaFunction(orig))
-      return orig;
-
-    std::string reprText = py::repr(py::cast(orig)).cast<std::string>();
-    throw SetPyError(PyExc_ValueError,
-                     Twine("Cannot cast affine expression to ") +
-                         DerivedTy::pyClassName + " (from " + reprText + ")");
-  }
-
-  /// Registers the Python class in module `m` with a casting constructor and
-  /// whatever DerivedTy::bindDerived adds.
-  static void bind(py::module &m) {
-    ClassTy cls(m, DerivedTy::pyClassName);
-    cls.def(py::init<PyAffineExpr &>());
-    DerivedTy::bindDerived(cls);
-  }
-
-  /// Hook for derived classes to add methods to the Python class; the default
-  /// adds nothing.
-  static void bindDerived(ClassTy &m) {}
-};
-
-/// Python binding for a constant affine expression (AffineConstantExpr).
-class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant;
-  static constexpr const char *pyClassName = "AffineConstantExpr";
-  using PyConcreteAffineExpr::PyConcreteAffineExpr;
-
-  /// Creates a constant affine expression holding `value`, widened to
-  /// int64_t for the C API.
-  static PyAffineConstantExpr get(intptr_t value,
-                                  DefaultingPyMlirContext context) {
-    const auto wideValue = static_cast<int64_t>(value);
-    return PyAffineConstantExpr(
-        context->getRef(), mlirAffineConstantExprGet(context->get(), wideValue));
-  }
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"),
-                 py::arg("context") = py::none());
-    // Read-only accessor for the constant's value.
-    c.def_property_readonly("value", [](PyAffineConstantExpr &self) {
-      auto constantValue = mlirAffineConstantExprGetValue(self);
-      return constantValue;
-    });
-  }
-};
-
-class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim;
-  static constexpr const char *pyClassName = "AffineDimExpr";
-  using PyConcreteAffineExpr::PyConcreteAffineExpr;
-
-  static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) {
-    MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos);
-    return PyAffineDimExpr(context->getRef(), affineExpr);
-  }
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static("get", &PyAffineDimExpr::get, py::arg("position"),
-                 py::arg("context") = py::none());
-    c.def_property_readonly("position", [](PyAffineDimExpr &self) {
-      return mlirAffineDimExprGetPosition(self);
-    });
-  }
-};
-
-class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol;
-  static constexpr const char *pyClassName = "AffineSymbolExpr";
-  using PyConcreteAffineExpr::PyConcreteAffineExpr;
-
-  static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) {
-    MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos);
-    return PyAffineSymbolExpr(context->getRef(), affineExpr);
-  }
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"),
-                 py::arg("context") = py::none());
-    c.def_property_readonly("position", [](PyAffineSymbolExpr &self) {
-      return mlirAffineSymbolExprGetPosition(self);
-    });
-  }
-};
-
-class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary;
-  static constexpr const char *pyClassName = "AffineBinaryExpr";
-  using PyConcreteAffineExpr::PyConcreteAffineExpr;
-
-  PyAffineExpr lhs() {
-    MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get());
-    return PyAffineExpr(getContext(), lhsExpr);
-  }
-
-  PyAffineExpr rhs() {
-    MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get());
-    return PyAffineExpr(getContext(), rhsExpr);
-  }
-
-  static void bindDerived(ClassTy &c) {
-    c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs);
-    c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs);
-  }
-};
-
-class PyAffineAddExpr
-    : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd;
-  static constexpr const char *pyClassName = "AffineAddExpr";
-  using PyConcreteAffineExpr::PyConcreteAffineExpr;
-
-  static PyAffineAddExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
-    MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs);
-    return PyAffineAddExpr(lhs.getContext(), expr);
-  }
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static("get", &PyAffineAddExpr::get);
-  }
-};
-
-class PyAffineMulExpr
-    : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul;
-  static constexpr const char *pyClassName = "AffineMulExpr";
-  using PyConcreteAffineExpr::PyConcreteAffineExpr;
-
-  static PyAffineMulExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
-    MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs);
-    return PyAffineMulExpr(lhs.getContext(), expr);
-  }
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static("get", &PyAffineMulExpr::get);
-  }
-};
-
-class PyAffineModExpr
-    : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod;
-  static constexpr const char *pyClassName = "AffineModExpr";
-  using PyConcreteAffineExpr::PyConcreteAffineExpr;
-
-  static PyAffineModExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
-    MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs);
-    return PyAffineModExpr(lhs.getContext(), expr);
-  }
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static("get", &PyAffineModExpr::get);
-  }
-};
-
-class PyAffineFloorDivExpr
-    : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv;
-  static constexpr const char *pyClassName = "AffineFloorDivExpr";
-  using PyConcreteAffineExpr::PyConcreteAffineExpr;
-
-  static PyAffineFloorDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
-    MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs);
-    return PyAffineFloorDivExpr(lhs.getContext(), expr);
-  }
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static("get", &PyAffineFloorDivExpr::get);
-  }
-};
-
-class PyAffineCeilDivExpr
-    : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv;
-  static constexpr const char *pyClassName = "AffineCeilDivExpr";
-  using PyConcreteAffineExpr::PyConcreteAffineExpr;
-
-  static PyAffineCeilDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
-    MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs);
-    return PyAffineCeilDivExpr(lhs.getContext(), expr);
-  }
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static("get", &PyAffineCeilDivExpr::get);
-  }
-};
-} // namespace
-
-bool PyAffineExpr::operator==(const PyAffineExpr &other) {
-  return mlirAffineExprEqual(affineExpr, other.affineExpr);
-}
-
-py::object PyAffineExpr::getCapsule() {
-  return py::reinterpret_steal<py::object>(
-      mlirPythonAffineExprToCapsule(*this));
-}
-
-PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) {
-  MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr());
-  if (mlirAffineExprIsNull(rawAffineExpr))
-    throw py::error_already_set();
-  return PyAffineExpr(
-      PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)),
-      rawAffineExpr);
-}
-
-//------------------------------------------------------------------------------
-// PyAffineMap and utilities.
-//------------------------------------------------------------------------------
-
-namespace {
-/// A list of expressions contained in an affine map. Internally these are
-/// stored as a consecutive array leading to inexpensive random access. Both
-/// the map and the expression are owned by the context so we need not bother
-/// with lifetime extension.
-class PyAffineMapExprList
-    : public Sliceable<PyAffineMapExprList, PyAffineExpr> {
-public:
-  static constexpr const char *pyClassName = "AffineExprList";
-
-  PyAffineMapExprList(PyAffineMap map, intptr_t startIndex = 0,
-                      intptr_t length = -1, intptr_t step = 1)
-      : Sliceable(startIndex,
-                  length == -1 ? mlirAffineMapGetNumResults(map) : length,
-                  step),
-        affineMap(map) {}
-
-  intptr_t getNumElements() { return mlirAffineMapGetNumResults(affineMap); }
-
-  PyAffineExpr getElement(intptr_t pos) {
-    return PyAffineExpr(affineMap.getContext(),
-                        mlirAffineMapGetResult(affineMap, pos));
-  }
-
-  PyAffineMapExprList slice(intptr_t startIndex, intptr_t length,
-                            intptr_t step) {
-    return PyAffineMapExprList(affineMap, startIndex, length, step);
-  }
-
-private:
-  PyAffineMap affineMap;
-};
-} // end namespace
-
-bool PyAffineMap::operator==(const PyAffineMap &other) {
-  return mlirAffineMapEqual(affineMap, other.affineMap);
-}
-
-py::object PyAffineMap::getCapsule() {
-  return py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(*this));
-}
-
-PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) {
-  MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr());
-  if (mlirAffineMapIsNull(rawAffineMap))
-    throw py::error_already_set();
-  return PyAffineMap(
-      PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)),
-      rawAffineMap);
-}
-
-//------------------------------------------------------------------------------
-// PyIntegerSet and utilities.
-//------------------------------------------------------------------------------
-
-class PyIntegerSetConstraint {
-public:
-  PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {}
-
-  PyAffineExpr getExpr() {
-    return PyAffineExpr(set.getContext(),
-                        mlirIntegerSetGetConstraint(set, pos));
-  }
-
-  bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); }
-
-  static void bind(py::module &m) {
-    py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint")
-        .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr)
-        .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq);
-  }
-
-private:
-  PyIntegerSet set;
-  intptr_t pos;
-};
-
-class PyIntegerSetConstraintList
-    : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> {
-public:
-  static constexpr const char *pyClassName = "IntegerSetConstraintList";
-
-  PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0,
-                             intptr_t length = -1, intptr_t step = 1)
-      : Sliceable(startIndex,
-                  length == -1 ? mlirIntegerSetGetNumConstraints(set) : length,
-                  step),
-        set(set) {}
-
-  intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); }
-
-  PyIntegerSetConstraint getElement(intptr_t pos) {
-    return PyIntegerSetConstraint(set, pos);
-  }
-
-  PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length,
-                                   intptr_t step) {
-    return PyIntegerSetConstraintList(set, startIndex, length, step);
-  }
-
-private:
-  PyIntegerSet set;
-};
-
-bool PyIntegerSet::operator==(const PyIntegerSet &other) {
-  return mlirIntegerSetEqual(integerSet, other.integerSet);
-}
-
-py::object PyIntegerSet::getCapsule() {
-  return py::reinterpret_steal<py::object>(
-      mlirPythonIntegerSetToCapsule(*this));
-}
-
-PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) {
-  MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr());
-  if (mlirIntegerSetIsNull(rawIntegerSet))
-    throw py::error_already_set();
-  return PyIntegerSet(
-      PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)),
-      rawIntegerSet);
-}
-
-/// Attempts to populate `result` with the content of `list` casted to the
-/// appropriate type (Python and C types are provided as template arguments).
-/// Throws errors in case of failure, using "action" to describe what the caller
-/// was attempting to do.
-template <typename PyType, typename CType>
-static void pyListToVector(py::list list, llvm::SmallVectorImpl<CType> &result,
-                           StringRef action) {
-  result.reserve(py::len(list));
-  for (py::handle item : list) {
-    try {
-      result.push_back(item.cast<PyType>());
-    } catch (py::cast_error &err) {
-      std::string msg = (llvm::Twine("Invalid expression when ") + action +
-                         " (" + err.what() + ")")
-                            .str();
-      throw py::cast_error(msg);
-    } catch (py::reference_cast_error &err) {
-      std::string msg = (llvm::Twine("Invalid expression (None?) when ") +
-                         action + " (" + err.what() + ")")
-                            .str();
-      throw py::cast_error(msg);
-    }
-  }
-}
-
-//------------------------------------------------------------------------------
-// Populates the pybind11 IR submodule.
-//------------------------------------------------------------------------------
-
-void mlir::python::populateIRSubmodule(py::module &m) {
-  //----------------------------------------------------------------------------
-  // Mapping of MlirContext
-  //----------------------------------------------------------------------------
-  py::class_<PyMlirContext>(m, "Context")
-      .def(py::init<>(&PyMlirContext::createNewContextForInit))
-      .def_static("_get_live_count", &PyMlirContext::getLiveCount)
-      .def("_get_context_again",
-           [](PyMlirContext &self) {
-             PyMlirContextRef ref = PyMlirContext::forContext(self.get());
-             return ref.releaseObject();
-           })
-      .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
-      .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
-      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
-                             &PyMlirContext::getCapsule)
-      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
-      .def("__enter__", &PyMlirContext::contextEnter)
-      .def("__exit__", &PyMlirContext::contextExit)
-      .def_property_readonly_static(
-          "current",
-          [](py::object & /*class*/) {
-            auto *context = PyThreadContextEntry::getDefaultContext();
-            if (!context)
-              throw SetPyError(PyExc_ValueError, "No current Context");
-            return context;
-          },
-          "Gets the Context bound to the current thread or raises ValueError")
-      .def_property_readonly(
-          "dialects",
-          [](PyMlirContext &self) { return PyDialects(self.getRef()); },
-          "Gets a container for accessing dialects by name")
-      .def_property_readonly(
-          "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
-          "Alias for 'dialect'")
-      .def(
-          "get_dialect_descriptor",
-          [=](PyMlirContext &self, std::string &name) {
-            MlirDialect dialect = mlirContextGetOrLoadDialect(
-                self.get(), {name.data(), name.size()});
-            if (mlirDialectIsNull(dialect)) {
-              throw SetPyError(PyExc_ValueError,
-                               Twine("Dialect '") + name + "' not found");
-            }
-            return PyDialectDescriptor(self.getRef(), dialect);
-          },
-          "Gets or loads a dialect by name, returning its descriptor object")
-      .def_property(
-          "allow_unregistered_dialects",
-          [](PyMlirContext &self) -> bool {
-            return mlirContextGetAllowUnregisteredDialects(self.get());
-          },
-          [](PyMlirContext &self, bool value) {
-            mlirContextSetAllowUnregisteredDialects(self.get(), value);
-          });
-
-  //----------------------------------------------------------------------------
-  // Mapping of PyDialectDescriptor
-  //----------------------------------------------------------------------------
-  py::class_<PyDialectDescriptor>(m, "DialectDescriptor")
-      .def_property_readonly("namespace",
-                             [](PyDialectDescriptor &self) {
-                               MlirStringRef ns =
-                                   mlirDialectGetNamespace(self.get());
-                               return py::str(ns.data, ns.length);
-                             })
-      .def("__repr__", [](PyDialectDescriptor &self) {
-        MlirStringRef ns = mlirDialectGetNamespace(self.get());
-        std::string repr("<DialectDescriptor ");
-        repr.append(ns.data, ns.length);
-        repr.append(">");
-        return repr;
-      });
-
-  //----------------------------------------------------------------------------
-  // Mapping of PyDialects
-  //----------------------------------------------------------------------------
-  py::class_<PyDialects>(m, "Dialects")
-      .def("__getitem__",
-           [=](PyDialects &self, std::string keyName) {
-             MlirDialect dialect =
-                 self.getDialectForKey(keyName, /*attrError=*/false);
-             py::object descriptor =
-                 py::cast(PyDialectDescriptor{self.getContext(), dialect});
-             return createCustomDialectWrapper(keyName, std::move(descriptor));
-           })
-      .def("__getattr__", [=](PyDialects &self, std::string attrName) {
-        MlirDialect dialect =
-            self.getDialectForKey(attrName, /*attrError=*/true);
-        py::object descriptor =
-            py::cast(PyDialectDescriptor{self.getContext(), dialect});
-        return createCustomDialectWrapper(attrName, std::move(descriptor));
-      });
-
-  //----------------------------------------------------------------------------
-  // Mapping of PyDialect
-  //----------------------------------------------------------------------------
-  py::class_<PyDialect>(m, "Dialect")
-      .def(py::init<py::object>(), "descriptor")
-      .def_property_readonly(
-          "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
-      .def("__repr__", [](py::object self) {
-        auto clazz = self.attr("__class__");
-        return py::str("<Dialect ") +
-               self.attr("descriptor").attr("namespace") + py::str(" (class ") +
-               clazz.attr("__module__") + py::str(".") +
-               clazz.attr("__name__") + py::str(")>");
-      });
-
-  //----------------------------------------------------------------------------
-  // Mapping of Location
-  //----------------------------------------------------------------------------
-  py::class_<PyLocation>(m, "Location")
-      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
-      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
-      .def("__enter__", &PyLocation::contextEnter)
-      .def("__exit__", &PyLocation::contextExit)
-      .def("__eq__",
-           [](PyLocation &self, PyLocation &other) -> bool {
-             return mlirLocationEqual(self, other);
-           })
-      .def("__eq__", [](PyLocation &self, py::object other) { return false; })
-      .def_property_readonly_static(
-          "current",
-          [](py::object & /*class*/) {
-            auto *loc = PyThreadContextEntry::getDefaultLocation();
-            if (!loc)
-              throw SetPyError(PyExc_ValueError, "No current Location");
-            return loc;
-          },
-          "Gets the Location bound to the current thread or raises ValueError")
-      .def_static(
-          "unknown",
-          [](DefaultingPyMlirContext context) {
-            return PyLocation(context->getRef(),
-                              mlirLocationUnknownGet(context->get()));
-          },
-          py::arg("context") = py::none(),
-          "Gets a Location representing an unknown location")
-      .def_static(
-          "file",
-          [](std::string filename, int line, int col,
-             DefaultingPyMlirContext context) {
-            return PyLocation(
-                context->getRef(),
-                mlirLocationFileLineColGet(
-                    context->get(), toMlirStringRef(filename), line, col));
-          },
-          py::arg("filename"), py::arg("line"), py::arg("col"),
-          py::arg("context") = py::none(), kContextGetFileLocationDocstring)
-      .def_property_readonly(
-          "context",
-          [](PyLocation &self) { return self.getContext().getObject(); },
-          "Context that owns the Location")
-      .def("__repr__", [](PyLocation &self) {
-        PyPrintAccumulator printAccum;
-        mlirLocationPrint(self, printAccum.getCallback(),
-                          printAccum.getUserData());
-        return printAccum.join();
-      });
+  //----------------------------------------------------------------------------
+  // Mapping of Location
+  //----------------------------------------------------------------------------
+  py::class_<PyLocation>(m, "Location")
+      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
+      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
+      .def("__enter__", &PyLocation::contextEnter)
+      .def("__exit__", &PyLocation::contextExit)
+      .def("__eq__",
+           [](PyLocation &self, PyLocation &other) -> bool {
+             return mlirLocationEqual(self, other);
+           })
+      .def("__eq__", [](PyLocation &self, py::object other) { return false; })
+      .def_property_readonly_static(
+          "current",
+          [](py::object & /*class*/) {
+            auto *loc = PyThreadContextEntry::getDefaultLocation();
+            if (!loc)
+              throw SetPyError(PyExc_ValueError, "No current Location");
+            return loc;
+          },
+          "Gets the Location bound to the current thread or raises ValueError")
+      .def_static(
+          "unknown",
+          [](DefaultingPyMlirContext context) {
+            return PyLocation(context->getRef(),
+                              mlirLocationUnknownGet(context->get()));
+          },
+          py::arg("context") = py::none(),
+          "Gets a Location representing an unknown location")
+      .def_static(
+          "file",
+          [](std::string filename, int line, int col,
+             DefaultingPyMlirContext context) {
+            return PyLocation(
+                context->getRef(),
+                mlirLocationFileLineColGet(
+                    context->get(), toMlirStringRef(filename), line, col));
+          },
+          py::arg("filename"), py::arg("line"), py::arg("col"),
+          py::arg("context") = py::none(), kContextGetFileLocationDocstring)
+      .def_property_readonly(
+          "context",
+          [](PyLocation &self) { return self.getContext().getObject(); },
+          "Context that owns the Location")
+      .def("__repr__", [](PyLocation &self) {
+        PyPrintAccumulator printAccum;
+        mlirLocationPrint(self, printAccum.getCallback(),
+                          printAccum.getUserData());
+        return printAccum.join();
+      });
 
   //----------------------------------------------------------------------------
   // Mapping of Module
@@ -4022,22 +2259,6 @@ void mlir::python::populateIRSubmodule(py::module &m) {
           py::keep_alive<0, 1>(),
           "The underlying generic attribute of the NamedAttribute binding");
 
-  // Builtin attribute bindings.
-  PyAffineMapAttribute::bind(m);
-  PyArrayAttribute::bind(m);
-  PyArrayAttribute::PyArrayAttributeIterator::bind(m);
-  PyBoolAttribute::bind(m);
-  PyDenseElementsAttribute::bind(m);
-  PyDenseFPElementsAttribute::bind(m);
-  PyDenseIntElementsAttribute::bind(m);
-  PyDictAttribute::bind(m);
-  PyFlatSymbolRefAttribute::bind(m);
-  PyFloatAttribute::bind(m);
-  PyIntegerAttribute::bind(m);
-  PyStringAttribute::bind(m);
-  PyTypeAttribute::bind(m);
-  PyUnitAttribute::bind(m);
-
   //----------------------------------------------------------------------------
   // Mapping of PyType.
   //----------------------------------------------------------------------------
@@ -4088,25 +2309,6 @@ void mlir::python::populateIRSubmodule(py::module &m) {
         return printAccum.join();
       });
 
-  // Builtin type bindings.
-  PyIntegerType::bind(m);
-  PyIndexType::bind(m);
-  PyBF16Type::bind(m);
-  PyF16Type::bind(m);
-  PyF32Type::bind(m);
-  PyF64Type::bind(m);
-  PyNoneType::bind(m);
-  PyComplexType::bind(m);
-  PyShapedType::bind(m);
-  PyVectorType::bind(m);
-  PyRankedTensorType::bind(m);
-  PyUnrankedTensorType::bind(m);
-  PyMemRefType::bind(m);
-  PyMemRefLayoutMapList::bind(m);
-  PyUnrankedMemRefType::bind(m);
-  PyTupleType::bind(m);
-  PyFunctionType::bind(m);
-
   //----------------------------------------------------------------------------
   // Mapping of Value.
   //----------------------------------------------------------------------------
@@ -4152,359 +2354,4 @@ void mlir::python::populateIRSubmodule(py::module &m) {
   PyOpResultList::bind(m);
   PyRegionIterator::bind(m);
   PyRegionList::bind(m);
-
-  //----------------------------------------------------------------------------
-  // Mapping of PyAffineExpr and derived classes.
-  //----------------------------------------------------------------------------
-  py::class_<PyAffineExpr>(m, "AffineExpr")
-      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
-                             &PyAffineExpr::getCapsule)
-      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule)
-      .def("__add__",
-           [](PyAffineExpr &self, PyAffineExpr &other) {
-             return PyAffineAddExpr::get(self, other);
-           })
-      .def("__mul__",
-           [](PyAffineExpr &self, PyAffineExpr &other) {
-             return PyAffineMulExpr::get(self, other);
-           })
-      .def("__mod__",
-           [](PyAffineExpr &self, PyAffineExpr &other) {
-             return PyAffineModExpr::get(self, other);
-           })
-      .def("__sub__",
-           [](PyAffineExpr &self, PyAffineExpr &other) {
-             auto negOne =
-                 PyAffineConstantExpr::get(-1, *self.getContext().get());
-             return PyAffineAddExpr::get(self,
-                                         PyAffineMulExpr::get(negOne, other));
-           })
-      .def("__eq__", [](PyAffineExpr &self,
-                        PyAffineExpr &other) { return self == other; })
-      .def("__eq__",
-           [](PyAffineExpr &self, py::object &other) { return false; })
-      .def("__str__",
-           [](PyAffineExpr &self) {
-             PyPrintAccumulator printAccum;
-             mlirAffineExprPrint(self, printAccum.getCallback(),
-                                 printAccum.getUserData());
-             return printAccum.join();
-           })
-      .def("__repr__",
-           [](PyAffineExpr &self) {
-             PyPrintAccumulator printAccum;
-             printAccum.parts.append("AffineExpr(");
-             mlirAffineExprPrint(self, printAccum.getCallback(),
-                                 printAccum.getUserData());
-             printAccum.parts.append(")");
-             return printAccum.join();
-           })
-      .def_property_readonly(
-          "context",
-          [](PyAffineExpr &self) { return self.getContext().getObject(); })
-      .def_static(
-          "get_add", &PyAffineAddExpr::get,
-          "Gets an affine expression containing a sum of two expressions.")
-      .def_static(
-          "get_mul", &PyAffineMulExpr::get,
-          "Gets an affine expression containing a product of two expressions.")
-      .def_static("get_mod", &PyAffineModExpr::get,
-                  "Gets an affine expression containing the modulo of dividing "
-                  "one expression by another.")
-      .def_static("get_floor_div", &PyAffineFloorDivExpr::get,
-                  "Gets an affine expression containing the rounded-down "
-                  "result of dividing one expression by another.")
-      .def_static("get_ceil_div", &PyAffineCeilDivExpr::get,
-                  "Gets an affine expression containing the rounded-up result "
-                  "of dividing one expression by another.")
-      .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"),
-                  py::arg("context") = py::none(),
-                  "Gets a constant affine expression with the given value.")
-      .def_static(
-          "get_dim", &PyAffineDimExpr::get, py::arg("position"),
-          py::arg("context") = py::none(),
-          "Gets an affine expression of a dimension at the given position.")
-      .def_static(
-          "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"),
-          py::arg("context") = py::none(),
-          "Gets an affine expression of a symbol at the given position.")
-      .def(
-          "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); },
-          kDumpDocstring);
-  PyAffineConstantExpr::bind(m);
-  PyAffineDimExpr::bind(m);
-  PyAffineSymbolExpr::bind(m);
-  PyAffineBinaryExpr::bind(m);
-  PyAffineAddExpr::bind(m);
-  PyAffineMulExpr::bind(m);
-  PyAffineModExpr::bind(m);
-  PyAffineFloorDivExpr::bind(m);
-  PyAffineCeilDivExpr::bind(m);
-
-  //----------------------------------------------------------------------------
-  // Mapping of PyAffineMap.
-  //----------------------------------------------------------------------------
-  py::class_<PyAffineMap>(m, "AffineMap")
-      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
-                             &PyAffineMap::getCapsule)
-      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule)
-      .def("__eq__",
-           [](PyAffineMap &self, PyAffineMap &other) { return self == other; })
-      .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; })
-      .def("__str__",
-           [](PyAffineMap &self) {
-             PyPrintAccumulator printAccum;
-             mlirAffineMapPrint(self, printAccum.getCallback(),
-                                printAccum.getUserData());
-             return printAccum.join();
-           })
-      .def("__repr__",
-           [](PyAffineMap &self) {
-             PyPrintAccumulator printAccum;
-             printAccum.parts.append("AffineMap(");
-             mlirAffineMapPrint(self, printAccum.getCallback(),
-                                printAccum.getUserData());
-             printAccum.parts.append(")");
-             return printAccum.join();
-           })
-      .def_property_readonly(
-          "context",
-          [](PyAffineMap &self) { return self.getContext().getObject(); },
-          "Context that owns the Affine Map")
-      .def(
-          "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); },
-          kDumpDocstring)
-      .def_static(
-          "get",
-          [](intptr_t dimCount, intptr_t symbolCount, py::list exprs,
-             DefaultingPyMlirContext context) {
-            SmallVector<MlirAffineExpr> affineExprs;
-            pyListToVector<PyAffineExpr, MlirAffineExpr>(
-                exprs, affineExprs, "attempting to create an AffineMap");
-            MlirAffineMap map =
-                mlirAffineMapGet(context->get(), dimCount, symbolCount,
-                                 affineExprs.size(), affineExprs.data());
-            return PyAffineMap(context->getRef(), map);
-          },
-          py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"),
-          py::arg("context") = py::none(),
-          "Gets a map with the given expressions as results.")
-      .def_static(
-          "get_constant",
-          [](intptr_t value, DefaultingPyMlirContext context) {
-            MlirAffineMap affineMap =
-                mlirAffineMapConstantGet(context->get(), value);
-            return PyAffineMap(context->getRef(), affineMap);
-          },
-          py::arg("value"), py::arg("context") = py::none(),
-          "Gets an affine map with a single constant result")
-      .def_static(
-          "get_empty",
-          [](DefaultingPyMlirContext context) {
-            MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get());
-            return PyAffineMap(context->getRef(), affineMap);
-          },
-          py::arg("context") = py::none(), "Gets an empty affine map.")
-      .def_static(
-          "get_identity",
-          [](intptr_t nDims, DefaultingPyMlirContext context) {
-            MlirAffineMap affineMap =
-                mlirAffineMapMultiDimIdentityGet(context->get(), nDims);
-            return PyAffineMap(context->getRef(), affineMap);
-          },
-          py::arg("n_dims"), py::arg("context") = py::none(),
-          "Gets an identity map with the given number of dimensions.")
-      .def_static(
-          "get_minor_identity",
-          [](intptr_t nDims, intptr_t nResults,
-             DefaultingPyMlirContext context) {
-            MlirAffineMap affineMap =
-                mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults);
-            return PyAffineMap(context->getRef(), affineMap);
-          },
-          py::arg("n_dims"), py::arg("n_results"),
-          py::arg("context") = py::none(),
-          "Gets a minor identity map with the given number of dimensions and "
-          "results.")
-      .def_static(
-          "get_permutation",
-          [](std::vector<unsigned> permutation,
-             DefaultingPyMlirContext context) {
-            if (!isPermutation(permutation))
-              throw py::cast_error("Invalid permutation when attempting to "
-                                   "create an AffineMap");
-            MlirAffineMap affineMap = mlirAffineMapPermutationGet(
-                context->get(), permutation.size(), permutation.data());
-            return PyAffineMap(context->getRef(), affineMap);
-          },
-          py::arg("permutation"), py::arg("context") = py::none(),
-          "Gets an affine map that permutes its inputs.")
-      .def("get_submap",
-           [](PyAffineMap &self, std::vector<intptr_t> &resultPos) {
-             intptr_t numResults = mlirAffineMapGetNumResults(self);
-             for (intptr_t pos : resultPos) {
-               if (pos < 0 || pos >= numResults)
-                 throw py::value_error("result position out of bounds");
-             }
-             MlirAffineMap affineMap = mlirAffineMapGetSubMap(
-                 self, resultPos.size(), resultPos.data());
-             return PyAffineMap(self.getContext(), affineMap);
-           })
-      .def("get_major_submap",
-           [](PyAffineMap &self, intptr_t nResults) {
-             if (nResults >= mlirAffineMapGetNumResults(self))
-               throw py::value_error("number of results out of bounds");
-             MlirAffineMap affineMap =
-                 mlirAffineMapGetMajorSubMap(self, nResults);
-             return PyAffineMap(self.getContext(), affineMap);
-           })
-      .def("get_minor_submap",
-           [](PyAffineMap &self, intptr_t nResults) {
-             if (nResults >= mlirAffineMapGetNumResults(self))
-               throw py::value_error("number of results out of bounds");
-             MlirAffineMap affineMap =
-                 mlirAffineMapGetMinorSubMap(self, nResults);
-             return PyAffineMap(self.getContext(), affineMap);
-           })
-      .def_property_readonly(
-          "is_permutation",
-          [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); })
-      .def_property_readonly("is_projected_permutation",
-                             [](PyAffineMap &self) {
-                               return mlirAffineMapIsProjectedPermutation(self);
-                             })
-      .def_property_readonly(
-          "n_dims",
-          [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); })
-      .def_property_readonly(
-          "n_inputs",
-          [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); })
-      .def_property_readonly(
-          "n_symbols",
-          [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); })
-      .def_property_readonly("results", [](PyAffineMap &self) {
-        return PyAffineMapExprList(self);
-      });
-  PyAffineMapExprList::bind(m);
-
-  //----------------------------------------------------------------------------
-  // Mapping of PyIntegerSet.
-  //----------------------------------------------------------------------------
-  py::class_<PyIntegerSet>(m, "IntegerSet")
-      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
-                             &PyIntegerSet::getCapsule)
-      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)
-      .def("__eq__", [](PyIntegerSet &self,
-                        PyIntegerSet &other) { return self == other; })
-      .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; })
-      .def("__str__",
-           [](PyIntegerSet &self) {
-             PyPrintAccumulator printAccum;
-             mlirIntegerSetPrint(self, printAccum.getCallback(),
-                                 printAccum.getUserData());
-             return printAccum.join();
-           })
-      .def("__repr__",
-           [](PyIntegerSet &self) {
-             PyPrintAccumulator printAccum;
-             printAccum.parts.append("IntegerSet(");
-             mlirIntegerSetPrint(self, printAccum.getCallback(),
-                                 printAccum.getUserData());
-             printAccum.parts.append(")");
-             return printAccum.join();
-           })
-      .def_property_readonly(
-          "context",
-          [](PyIntegerSet &self) { return self.getContext().getObject(); })
-      .def(
-          "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
-          kDumpDocstring)
-      .def_static(
-          "get",
-          [](intptr_t numDims, intptr_t numSymbols, py::list exprs,
-             std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
-            if (exprs.size() != eqFlags.size())
-              throw py::value_error(
-                  "Expected the number of constraints to match "
-                  "that of equality flags");
-            if (exprs.empty())
-              throw py::value_error("Expected non-empty list of constraints");
-
-            // Copy over to a SmallVector because std::vector has a
-            // specialization for booleans that packs data and does not
-            // expose a `bool *`.
-            SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end());
-
-            SmallVector<MlirAffineExpr> affineExprs;
-            pyListToVector<PyAffineExpr>(exprs, affineExprs,
-                                         "attempting to create an IntegerSet");
-            MlirIntegerSet set = mlirIntegerSetGet(
-                context->get(), numDims, numSymbols, exprs.size(),
-                affineExprs.data(), flags.data());
-            return PyIntegerSet(context->getRef(), set);
-          },
-          py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"),
-          py::arg("eq_flags"), py::arg("context") = py::none())
-      .def_static(
-          "get_empty",
-          [](intptr_t numDims, intptr_t numSymbols,
-             DefaultingPyMlirContext context) {
-            MlirIntegerSet set =
-                mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols);
-            return PyIntegerSet(context->getRef(), set);
-          },
-          py::arg("num_dims"), py::arg("num_symbols"),
-          py::arg("context") = py::none())
-      .def("get_replaced",
-           [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs,
-              intptr_t numResultDims, intptr_t numResultSymbols) {
-             if (static_cast<intptr_t>(dimExprs.size()) !=
-                 mlirIntegerSetGetNumDims(self))
-               throw py::value_error(
-                   "Expected the number of dimension replacement expressions "
-                   "to match that of dimensions");
-             if (static_cast<intptr_t>(symbolExprs.size()) !=
-                 mlirIntegerSetGetNumSymbols(self))
-               throw py::value_error(
-                   "Expected the number of symbol replacement expressions "
-                   "to match that of symbols");
-
-             SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs;
-             pyListToVector<PyAffineExpr>(
-                 dimExprs, dimAffineExprs,
-                 "attempting to create an IntegerSet by replacing dimensions");
-             pyListToVector<PyAffineExpr>(
-                 symbolExprs, symbolAffineExprs,
-                 "attempting to create an IntegerSet by replacing symbols");
-             MlirIntegerSet set = mlirIntegerSetReplaceGet(
-                 self, dimAffineExprs.data(), symbolAffineExprs.data(),
-                 numResultDims, numResultSymbols);
-             return PyIntegerSet(self.getContext(), set);
-           })
-      .def_property_readonly("is_canonical_empty",
-                             [](PyIntegerSet &self) {
-                               return mlirIntegerSetIsCanonicalEmpty(self);
-                             })
-      .def_property_readonly(
-          "n_dims",
-          [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); })
-      .def_property_readonly(
-          "n_symbols",
-          [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); })
-      .def_property_readonly(
-          "n_inputs",
-          [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); })
-      .def_property_readonly("n_equalities",
-                             [](PyIntegerSet &self) {
-                               return mlirIntegerSetGetNumEqualities(self);
-                             })
-      .def_property_readonly("n_inequalities",
-                             [](PyIntegerSet &self) {
-                               return mlirIntegerSetGetNumInequalities(self);
-                             })
-      .def_property_readonly("constraints", [](PyIntegerSet &self) {
-        return PyIntegerSetConstraintList(self);
-      });
-  PyIntegerSetConstraint::bind(m);
-  PyIntegerSetConstraintList::bind(m);
 }

diff  --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModule.h
similarity index 99%
rename from mlir/lib/Bindings/Python/IRModules.h
rename to mlir/lib/Bindings/Python/IRModule.h
index 8140d704300d..5c710abe789a 100644
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -747,7 +747,10 @@ class PyIntegerSet : public BaseContextObject {
   MlirIntegerSet integerSet;
 };
 
-void populateIRSubmodule(pybind11::module &m);
+void populateIRAffine(pybind11::module &m);
+void populateIRAttributes(pybind11::module &m);
+void populateIRCore(pybind11::module &m);
+void populateIRTypes(pybind11::module &m);
 
 } // namespace python
 } // namespace mlir

diff  --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
new file mode 100644
index 000000000000..96f6bf6666c9
--- /dev/null
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -0,0 +1,678 @@
+//===- IRTypes.cpp - Exports builtin and standard types -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "IRModule.h"
+
+#include "PybindUtils.h"
+
+#include "mlir-c/BuiltinTypes.h"
+
+namespace py = pybind11;
+using namespace mlir;
+using namespace mlir::python;
+
+using llvm::SmallVector;
+using llvm::Twine;
+
+namespace {
+
+/// Returns nonzero iff `type` is an integer, bf16, f16, f32 or f64 type.
+static int mlirTypeIsAIntegerOrFloat(MlirType type) {
+  return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
+         mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
+}
+
+/// CRTP base class for Python types that subclass Type and should be
+/// castable from it (i.e. via something like IntegerType(t)).
+/// By default, type class hierarchies are one level deep (i.e. a
+/// concrete type class extends PyType); however, intermediate python-visible
+/// base classes can be modeled by specifying a BaseTy.
+template <typename DerivedTy, typename BaseTy = PyType>
+class PyConcreteType : public BaseTy {
+public:
+  // Derived classes must define statics for:
+  //   IsAFunctionTy isaFunction  (predicate used by castFrom and isinstance)
+  //   const char *pyClassName    (name the class is registered under)
+  using ClassTy = py::class_<DerivedTy, BaseTy>;
+  using IsAFunctionTy = bool (*)(MlirType);
+
+  PyConcreteType() = default;
+  PyConcreteType(PyMlirContextRef contextRef, MlirType t)
+      : BaseTy(std::move(contextRef), t) {}
+  PyConcreteType(PyType &orig)
+      : PyConcreteType(orig.getContext(), castFrom(orig)) {}
+  /// Returns `orig` unchanged when isaFunction accepts it; else ValueError.
+  static MlirType castFrom(PyType &orig) {
+    if (!DerivedTy::isaFunction(orig)) {
+      auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
+      throw SetPyError(PyExc_ValueError, Twine("Cannot cast type to ") +
+                                             DerivedTy::pyClassName +
+                                             " (from " + origRepr + ")");
+    }
+    return orig;
+  }
+  /// Registers the Python class on `m`, then delegates to bindDerived.
+  static void bind(py::module &m) {
+    auto cls = ClassTy(m, DerivedTy::pyClassName);
+    cls.def(py::init<PyType &>(), py::keep_alive<0, 1>());
+    cls.def_static("isinstance", [](PyType &otherType) -> bool {
+      return DerivedTy::isaFunction(otherType);
+    });
+    DerivedTy::bindDerived(cls);
+  }
+
+  /// Implemented by derived classes to add methods to the Python subclass.
+  static void bindDerived(ClassTy &m) {}
+};
+
+class PyIntegerType : public PyConcreteType<PyIntegerType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
+  static constexpr const char *pyClassName = "IntegerType";
+  using PyConcreteType::PyConcreteType;
+  /// Binds the `get_*` factories plus the width and signedness properties.
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get_signless",
+        [](unsigned width, DefaultingPyMlirContext context) {
+          MlirType t = mlirIntegerTypeGet(context->get(), width);
+          return PyIntegerType(context->getRef(), t);
+        },
+        py::arg("width"), py::arg("context") = py::none(),
+        "Create a signless integer type");
+    c.def_static(
+        "get_signed",
+        [](unsigned width, DefaultingPyMlirContext context) {
+          MlirType t = mlirIntegerTypeSignedGet(context->get(), width);
+          return PyIntegerType(context->getRef(), t);
+        },
+        py::arg("width"), py::arg("context") = py::none(),
+        "Create a signed integer type");
+    c.def_static(
+        "get_unsigned",
+        [](unsigned width, DefaultingPyMlirContext context) {
+          MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width);
+          return PyIntegerType(context->getRef(), t);
+        },
+        py::arg("width"), py::arg("context") = py::none(),
+        "Create an unsigned integer type");
+    c.def_property_readonly(
+        "width",
+        [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
+        "Returns the width of the integer type");
+    c.def_property_readonly(
+        "is_signless",
+        [](PyIntegerType &self) -> bool {
+          return mlirIntegerTypeIsSignless(self);
+        },
+        "Returns whether this is a signless integer");
+    c.def_property_readonly(
+        "is_signed",
+        [](PyIntegerType &self) -> bool {
+          return mlirIntegerTypeIsSigned(self);
+        },
+        "Returns whether this is a signed integer");
+    c.def_property_readonly(
+        "is_unsigned",
+        [](PyIntegerType &self) -> bool {
+          return mlirIntegerTypeIsUnsigned(self);
+        },
+        "Returns whether this is an unsigned integer");
+  }
+};
+
+/// Python `IndexType` binding, recognized via mlirTypeIsAIndex.
+class PyIndexType : public PyConcreteType<PyIndexType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
+  static constexpr const char *pyClassName = "IndexType";
+  using PyConcreteType::PyConcreteType;
+  /// Binds static `get`, whose `context` argument defaults to None.
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          MlirType t = mlirIndexTypeGet(context->get());
+          return PyIndexType(context->getRef(), t);
+        },
+        py::arg("context") = py::none(), "Create a index type.");
+  }
+};
+
+/// Python `BF16Type` binding, recognized via mlirTypeIsABF16.
+class PyBF16Type : public PyConcreteType<PyBF16Type> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
+  static constexpr const char *pyClassName = "BF16Type";
+  using PyConcreteType::PyConcreteType;
+  /// Binds static `get`, whose `context` argument defaults to None.
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          MlirType t = mlirBF16TypeGet(context->get());
+          return PyBF16Type(context->getRef(), t);
+        },
+        py::arg("context") = py::none(), "Create a bf16 type.");
+  }
+};
+
+/// Python `F16Type` binding, recognized via mlirTypeIsAF16.
+class PyF16Type : public PyConcreteType<PyF16Type> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
+  static constexpr const char *pyClassName = "F16Type";
+  using PyConcreteType::PyConcreteType;
+  /// Binds static `get`, whose `context` argument defaults to None.
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          MlirType t = mlirF16TypeGet(context->get());
+          return PyF16Type(context->getRef(), t);
+        },
+        py::arg("context") = py::none(), "Create a f16 type.");
+  }
+};
+
+/// Python `F32Type` binding, recognized via mlirTypeIsAF32.
+class PyF32Type : public PyConcreteType<PyF32Type> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
+  static constexpr const char *pyClassName = "F32Type";
+  using PyConcreteType::PyConcreteType;
+  /// Binds static `get`, whose `context` argument defaults to None.
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          MlirType t = mlirF32TypeGet(context->get());
+          return PyF32Type(context->getRef(), t);
+        },
+        py::arg("context") = py::none(), "Create a f32 type.");
+  }
+};
+
+/// Python `F64Type` binding, recognized via mlirTypeIsAF64.
+class PyF64Type : public PyConcreteType<PyF64Type> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
+  static constexpr const char *pyClassName = "F64Type";
+  using PyConcreteType::PyConcreteType;
+  /// Binds static `get`, whose `context` argument defaults to None.
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          MlirType t = mlirF64TypeGet(context->get());
+          return PyF64Type(context->getRef(), t);
+        },
+        py::arg("context") = py::none(), "Create a f64 type.");
+  }
+};
+
+/// Python `NoneType` binding, recognized via mlirTypeIsANone.
+class PyNoneType : public PyConcreteType<PyNoneType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
+  static constexpr const char *pyClassName = "NoneType";
+  using PyConcreteType::PyConcreteType;
+  /// Binds static `get`, whose `context` argument defaults to None.
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          MlirType t = mlirNoneTypeGet(context->get());
+          return PyNoneType(context->getRef(), t);
+        },
+        py::arg("context") = py::none(), "Create a none type.");
+  }
+};
+
+/// Python `ComplexType` binding, recognized via mlirTypeIsAComplex.
+class PyComplexType : public PyConcreteType<PyComplexType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
+  static constexpr const char *pyClassName = "ComplexType";
+  using PyConcreteType::PyConcreteType;
+  /// Binds static `get` (takes the element type) and `element_type`.
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](PyType &elementType) {
+          // Only integer/bf16/f16/f32/f64 elements are accepted; anything
+          // else falls through to the ValueError below.
+          if (mlirTypeIsAIntegerOrFloat(elementType)) {
+            MlirType t = mlirComplexTypeGet(elementType);
+            return PyComplexType(elementType.getContext(), t);
+          }
+          throw SetPyError(
+              PyExc_ValueError,
+              Twine("invalid '") +
+                  py::repr(py::cast(elementType)).cast<std::string>() +
+                  "' and expected floating point or integer type.");
+        },
+        "Create a complex type");
+    c.def_property_readonly(
+        "element_type",
+        [](PyComplexType &self) -> PyType {
+          MlirType t = mlirComplexTypeGetElementType(self);
+          return PyType(self.getContext(), t);
+        },
+        "Returns element type.");
+  }
+};
+
+class PyShapedType : public PyConcreteType<PyShapedType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAShaped;
+  static constexpr const char *pyClassName = "ShapedType";
+  using PyConcreteType::PyConcreteType;
+  /// Binds shape queries; those calling requireHasRank can raise ValueError.
+  static void bindDerived(ClassTy &c) {
+    c.def_property_readonly(
+        "element_type",
+        [](PyShapedType &self) {
+          MlirType t = mlirShapedTypeGetElementType(self);
+          return PyType(self.getContext(), t);
+        },
+        "Returns the element type of the shaped type.");
+    c.def_property_readonly(
+        "has_rank",
+        [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); },
+        "Returns whether the given shaped type is ranked.");
+    c.def_property_readonly(
+        "rank",
+        [](PyShapedType &self) {
+          self.requireHasRank();
+          return mlirShapedTypeGetRank(self);
+        },
+        "Returns the rank of the given ranked shaped type.");
+    c.def_property_readonly(
+        "has_static_shape",
+        [](PyShapedType &self) -> bool {
+          return mlirShapedTypeHasStaticShape(self);
+        },
+        "Returns whether the given shaped type has a static shape.");
+    c.def(
+        "is_dynamic_dim",
+        [](PyShapedType &self, intptr_t dim) -> bool {
+          self.requireHasRank();
+          return mlirShapedTypeIsDynamicDim(self, dim);
+        },
+        "Returns whether the dim-th dimension of the given shaped type is "
+        "dynamic.");
+    c.def(
+        "get_dim_size",
+        [](PyShapedType &self, intptr_t dim) {
+          self.requireHasRank();
+          return mlirShapedTypeGetDimSize(self, dim);
+        },
+        "Returns the dim-th dimension of the given ranked shaped type.");
+    c.def_static(
+        "is_dynamic_size",
+        [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); },
+        "Returns whether the given dimension size indicates a dynamic "
+        "dimension.");
+    c.def(
+        "is_dynamic_stride_or_offset",
+        [](PyShapedType &self, int64_t val) -> bool {
+          self.requireHasRank(); // NOTE(review): `self` is otherwise unused.
+          return mlirShapedTypeIsDynamicStrideOrOffset(val);
+        },
+        "Returns whether the given value is used as a placeholder for dynamic "
+        "strides and offsets in shaped types.");
+  }
+
+private:
+  void requireHasRank() { // Raises ValueError if the type has no rank.
+    if (!mlirShapedTypeHasRank(*this)) {
+      throw SetPyError(
+          PyExc_ValueError,
+          "calling this method requires that the type has a rank.");
+    }
+  }
+};
+
+/// Python `VectorType` (a ShapedType), recognized via mlirTypeIsAVector.
+class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
+  static constexpr const char *pyClassName = "VectorType";
+  using PyConcreteType::PyConcreteType;
+  /// NOTE(review): kwarg `elementType` differs from siblings' `element_type`.
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](std::vector<int64_t> shape, PyType &elementType,
+           DefaultingPyLocation loc) {
+          MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
+                                                elementType);
+          // TODO: Rework error reporting once diagnostic engine is exposed
+          // in C API. Until then a null result is reported as a ValueError.
+          if (mlirTypeIsNull(t)) {
+            throw SetPyError(
+                PyExc_ValueError,
+                Twine("invalid '") +
+                    py::repr(py::cast(elementType)).cast<std::string>() +
+                    "' and expected floating point or integer type.");
+          }
+          return PyVectorType(elementType.getContext(), t);
+        },
+        py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(),
+        "Create a vector type");
+  }
+};
+
+/// Python `RankedTensorType` (a ShapedType), via mlirTypeIsARankedTensor.
+class PyRankedTensorType
+    : public PyConcreteType<PyRankedTensorType, PyShapedType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
+  static constexpr const char *pyClassName = "RankedTensorType";
+  using PyConcreteType::PyConcreteType;
+  /// Binds static `get`; raises ValueError if the checked getter returns null.
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](std::vector<int64_t> shape, PyType &elementType,
+           DefaultingPyLocation loc) {
+          MlirType t = mlirRankedTensorTypeGetChecked(
+              loc, shape.size(), shape.data(), elementType);
+          // TODO: Rework error reporting once diagnostic engine is exposed
+          // in C API.
+          if (mlirTypeIsNull(t)) {
+            throw SetPyError(
+                PyExc_ValueError,
+                Twine("invalid '") +
+                    py::repr(py::cast(elementType)).cast<std::string>() +
+                    "' and expected floating point, integer, vector or "
+                    "complex "
+                    "type.");
+          }
+          return PyRankedTensorType(elementType.getContext(), t);
+        },
+        py::arg("shape"), py::arg("element_type"), py::arg("loc") = py::none(),
+        "Create a ranked tensor type");
+  }
+};
+
+/// Python `UnrankedTensorType` (a ShapedType), via mlirTypeIsAUnrankedTensor.
+class PyUnrankedTensorType
+    : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
+  static constexpr const char *pyClassName = "UnrankedTensorType";
+  using PyConcreteType::PyConcreteType;
+  /// Binds static `get`; raises ValueError if the checked getter returns null.
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](PyType &elementType, DefaultingPyLocation loc) {
+          MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType);
+          // TODO: Rework error reporting once diagnostic engine is exposed
+          // in C API.
+          if (mlirTypeIsNull(t)) {
+            throw SetPyError(
+                PyExc_ValueError,
+                Twine("invalid '") +
+                    py::repr(py::cast(elementType)).cast<std::string>() +
+                    "' and expected floating point, integer, vector or "
+                    "complex "
+                    "type.");
+          }
+          return PyUnrankedTensorType(elementType.getContext(), t);
+        },
+        py::arg("element_type"), py::arg("loc") = py::none(),
+        "Create a unranked tensor type");
+  }
+};
+
+class PyMemRefLayoutMapList;
+
+/// Ranked MemRef Type subclass - MemRefType.
+class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
+public:
+  // Gates casts/isinstance, so it must be the memref (not tensor) predicate.
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
+  static constexpr const char *pyClassName = "MemRefType";
+  using PyConcreteType::PyConcreteType;
+  /// Returns the list view over this type's affine layout maps.
+  PyMemRefLayoutMapList getLayout();
+  /// Binds static `get` plus the `layout` and `memory_space` properties.
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+         "get",
+         [](std::vector<int64_t> shape, PyType &elementType,
+            std::vector<PyAffineMap> layout, PyAttribute *memorySpace,
+            DefaultingPyLocation loc) {
+           SmallVector<MlirAffineMap> maps;
+           maps.reserve(layout.size());
+           for (PyAffineMap &map : layout)
+             maps.push_back(map);
+
+           MlirAttribute memSpaceAttr = {};
+           if (memorySpace)
+             memSpaceAttr = *memorySpace;
+
+           MlirType t = mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
+                                                 shape.data(), maps.size(),
+                                                 maps.data(), memSpaceAttr);
+           // TODO: Rework error reporting once diagnostic engine is exposed
+           // in C API.
+           if (mlirTypeIsNull(t)) {
+             throw SetPyError(
+                 PyExc_ValueError,
+                 Twine("invalid '") +
+                     py::repr(py::cast(elementType)).cast<std::string>() +
+                     "' and expected floating point, integer, vector or "
+                     "complex type.");
+           }
+           return PyMemRefType(elementType.getContext(), t);
+         },
+         py::arg("shape"), py::arg("element_type"),
+         py::arg("layout") = py::list(), py::arg("memory_space") = py::none(),
+         py::arg("loc") = py::none(), "Create a memref type")
+        .def_property_readonly("layout", &PyMemRefType::getLayout,
+                               "The list of layout maps of the MemRef type.")
+        .def_property_readonly(
+            "memory_space",
+            [](PyMemRefType &self) -> PyAttribute {
+              MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
+              return PyAttribute(self.getContext(), a);
+            },
+            "Returns the memory space of the given MemRef type.");
+  }
+};
+
+/// A list of affine layout maps in a memref type. Internally, these are stored
+/// as consecutive elements, so random access is cheap. Both the type and the
+/// maps are owned by the context, no need to worry about lifetime extension.
+class PyMemRefLayoutMapList
+    : public Sliceable<PyMemRefLayoutMapList, PyAffineMap> {
+public:
+  static constexpr const char *pyClassName = "MemRefLayoutMapList";
+  /// A `length` of -1 selects every layout map of `type`.
+  PyMemRefLayoutMapList(PyMemRefType type, intptr_t startIndex = 0,
+                        intptr_t length = -1, intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirMemRefTypeGetNumAffineMaps(type) : length,
+                  step),
+        memref(type) {}
+  /// Number of layout maps on the underlying memref (ignores the slice).
+  intptr_t getNumElements() { return mlirMemRefTypeGetNumAffineMaps(memref); }
+  /// Wraps the `index`-th layout map of the underlying memref.
+  PyAffineMap getElement(intptr_t index) {
+    return PyAffineMap(memref.getContext(),
+                       mlirMemRefTypeGetAffineMap(memref, index));
+  }
+  /// Builds a list over the same memref with the given slice parameters.
+  PyMemRefLayoutMapList slice(intptr_t startIndex, intptr_t length,
+                              intptr_t step) {
+    return PyMemRefLayoutMapList(memref, startIndex, length, step);
+  }
+
+private:
+  PyMemRefType memref;
+};
+
+PyMemRefLayoutMapList PyMemRefType::getLayout() {
+  return PyMemRefLayoutMapList(*this); // default slice: start 0, all maps
+}
+
+/// Unranked MemRef Type subclass - UnrankedMemRefType.
+class PyUnrankedMemRefType
+    : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
+  static constexpr const char *pyClassName = "UnrankedMemRefType";
+  using PyConcreteType::PyConcreteType;
+  /// Binds static `get` and the read-only `memory_space` property.
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+         "get",
+         [](PyType &elementType, PyAttribute *memorySpace,
+            DefaultingPyLocation loc) {
+           MlirAttribute memSpaceAttr = {};
+           if (memorySpace)
+             memSpaceAttr = *memorySpace;
+
+           MlirType t =
+               mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr);
+           // TODO: Rework error reporting once diagnostic engine is exposed
+           // in C API.
+           if (mlirTypeIsNull(t)) {
+             throw SetPyError(
+                 PyExc_ValueError,
+                 Twine("invalid '") +
+                     py::repr(py::cast(elementType)).cast<std::string>() +
+                     "' and expected floating point, integer, vector or "
+                     "complex type.");
+           }
+           return PyUnrankedMemRefType(elementType.getContext(), t);
+         },
+         py::arg("element_type"), py::arg("memory_space"),
+         py::arg("loc") = py::none(), "Create a unranked memref type")
+        .def_property_readonly(
+            "memory_space",
+            [](PyUnrankedMemRefType &self) -> PyAttribute {
+              // Unranked accessor: the ranked one expects a MemRefType.
+              MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
+              return PyAttribute(self.getContext(), a);
+            },
+            "Returns the memory space of the given Unranked MemRef type.");
+  }
+};
+
+/// Python `TupleType` binding, recognized via mlirTypeIsATuple.
+class PyTupleType : public PyConcreteType<PyTupleType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
+  static constexpr const char *pyClassName = "TupleType";
+  using PyConcreteType::PyConcreteType;
+  /// Binds static `get_tuple`, method `get_type`, and property `num_types`.
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get_tuple",
+        [](py::list elementList, DefaultingPyMlirContext context) {
+          intptr_t num = py::len(elementList);
+          // Copy the py::list into a SmallVector, casting each item.
+          SmallVector<MlirType, 4> elements;
+          for (auto element : elementList)
+            elements.push_back(element.cast<PyType>());
+          MlirType t = mlirTupleTypeGet(context->get(), num, elements.data());
+          return PyTupleType(context->getRef(), t);
+        },
+        py::arg("elements"), py::arg("context") = py::none(),
+        "Create a tuple type");
+    c.def(
+        "get_type",
+        [](PyTupleType &self, intptr_t pos) -> PyType {
+          MlirType t = mlirTupleTypeGetType(self, pos);
+          return PyType(self.getContext(), t);
+        },
+        "Returns the pos-th type in the tuple type.");
+    c.def_property_readonly(
+        "num_types",
+        [](PyTupleType &self) -> intptr_t {
+          return mlirTupleTypeGetNumTypes(self);
+        },
+        "Returns the number of types contained in a tuple.");
+  }
+};
+
+/// Python `FunctionType` binding, recognized via mlirTypeIsAFunction.
+class PyFunctionType : public PyConcreteType<PyFunctionType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
+  static constexpr const char *pyClassName = "FunctionType";
+  using PyConcreteType::PyConcreteType;
+  /// Binds static `get` and the `inputs`/`results` list properties.
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](std::vector<PyType> inputs, std::vector<PyType> results,
+           DefaultingPyMlirContext context) {
+          SmallVector<MlirType, 4> inputsRaw(inputs.begin(), inputs.end());
+          SmallVector<MlirType, 4> resultsRaw(results.begin(), results.end());
+          MlirType t = mlirFunctionTypeGet(context->get(), inputsRaw.size(),
+                                           inputsRaw.data(), resultsRaw.size(),
+                                           resultsRaw.data());
+          return PyFunctionType(context->getRef(), t);
+        },
+        py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(),
+        "Gets a FunctionType from a list of input and result types");
+    c.def_property_readonly(
+        "inputs",
+        [](PyFunctionType &self) {
+          MlirType t = self;
+          auto contextRef = self.getContext();
+          py::list types;
+          for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
+               ++i) {
+            types.append(PyType(contextRef, mlirFunctionTypeGetInput(t, i)));
+          }
+          return types;
+        },
+        "Returns the list of input types in the FunctionType.");
+    c.def_property_readonly(
+        "results",
+        [](PyFunctionType &self) {
+          auto contextRef = self.getContext();
+          py::list types;
+          for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
+               ++i) {
+            types.append(
+                PyType(contextRef, mlirFunctionTypeGetResult(self, i)));
+          }
+          return types;
+        },
+        "Returns the list of result types in the FunctionType.");
+  }
+};
+
+} // namespace
+
+void mlir::python::populateIRTypes(py::module &m) { // Bind all type classes.
+  PyIntegerType::bind(m);
+  PyIndexType::bind(m);
+  PyBF16Type::bind(m);
+  PyF16Type::bind(m);
+  PyF32Type::bind(m);
+  PyF64Type::bind(m);
+  PyNoneType::bind(m);
+  PyComplexType::bind(m);
+  PyShapedType::bind(m); // Base of the vector/tensor/memref classes below.
+  PyVectorType::bind(m);
+  PyRankedTensorType::bind(m);
+  PyUnrankedTensorType::bind(m);
+  PyMemRefType::bind(m);
+  PyMemRefLayoutMapList::bind(m);
+  PyUnrankedMemRefType::bind(m);
+  PyTupleType::bind(m);
+  PyFunctionType::bind(m);
+}

diff  --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 9bfe8b09f6db..5fe0401afaeb 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -12,7 +12,7 @@
 
 #include "ExecutionEngine.h"
 #include "Globals.h"
-#include "IRModules.h"
+#include "IRModule.h"
 #include "Pass.h"
 
 namespace py = pybind11;
@@ -211,7 +211,10 @@ PYBIND11_MODULE(_mlir, m) {
 
   // Define and populate IR submodule.
   auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
-  populateIRSubmodule(irModule);
+  populateIRCore(irModule);
+  populateIRAffine(irModule);
+  populateIRAttributes(irModule);
+  populateIRTypes(irModule);
 
   // Define and populate PassManager submodule.
   auto passModule =

diff  --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index dd57647f0327..0e2f5bafb465 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -8,7 +8,7 @@
 
 #include "Pass.h"
 
-#include "IRModules.h"
+#include "IRModule.h"
 #include "mlir-c/Bindings/Python/Interop.h"
 #include "mlir-c/Pass.h"
 


        


More information about the Mlir-commits mailing list