[Mlir-commits] [mlir] 74628c4 - [mlir] Add Python bindings for AffineExpr

Alex Zinenko llvmlistbot at llvm.org
Mon Jan 11 10:57:25 PST 2021


Author: Alex Zinenko
Date: 2021-01-11T19:57:13+01:00
New Revision: 74628c43053b482f35f0f1e6b4eac743fbe425e5

URL: https://github.com/llvm/llvm-project/commit/74628c43053b482f35f0f1e6b4eac743fbe425e5
DIFF: https://github.com/llvm/llvm-project/commit/74628c43053b482f35f0f1e6b4eac743fbe425e5.diff

LOG: [mlir] Add Python bindings for AffineExpr

This adds the Python bindings for AffineExpr and a couple of utility functions
to the C API. AffineExpr is a top-level context-owned object and is modeled
similarly to attributes and types. It is required, e.g., to build layout maps
of the built-in memref type.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D94225

Added: 
    mlir/test/Bindings/Python/ir_affine_expr.py

Modified: 
    mlir/include/mlir-c/AffineExpr.h
    mlir/include/mlir-c/Bindings/Python/Interop.h
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/lib/Bindings/Python/IRModules.h
    mlir/lib/CAPI/IR/AffineExpr.cpp
    mlir/test/CAPI/ir.c

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/AffineExpr.h b/mlir/include/mlir-c/AffineExpr.h
index 93b8e832b44f..2eb8ae03e03d 100644
--- a/mlir/include/mlir-c/AffineExpr.h
+++ b/mlir/include/mlir-c/AffineExpr.h
@@ -45,6 +45,16 @@ DEFINE_C_API_STRUCT(MlirAffineExpr, const void);
 MLIR_CAPI_EXPORTED MlirContext
 mlirAffineExprGetContext(MlirAffineExpr affineExpr);
 
+/// Returns `true` if the two affine expressions are equal.
+MLIR_CAPI_EXPORTED bool mlirAffineExprEqual(MlirAffineExpr lhs,
+                                            MlirAffineExpr rhs);
+
+/// Returns `true` if the given affine expression is null (constant zero is
+/// not null). `static` as plain C99 `inline` provides no external definition.
+static inline bool mlirAffineExprIsNull(MlirAffineExpr affineExpr) {
+  return affineExpr.ptr == NULL;
+}
+
 /** Prints an affine expression by sending chunks of the string representation
  * and forwarding `userData to `callback`. Note that the callback may be called
  * several times with consecutive chunks of the string. */
@@ -82,6 +92,9 @@ MLIR_CAPI_EXPORTED bool mlirAffineExprIsFunctionOfDim(MlirAffineExpr affineExpr,
 // Affine Dimension Expression.
 //===----------------------------------------------------------------------===//
 
+/// Checks whether the given affine expression is a dimension expression.
+MLIR_CAPI_EXPORTED bool mlirAffineExprIsADim(MlirAffineExpr affineExpr);
+
 /// Creates an affine dimension expression with 'position' in the context.
 MLIR_CAPI_EXPORTED MlirAffineExpr mlirAffineDimExprGet(MlirContext ctx,
                                                        intptr_t position);
@@ -94,6 +107,9 @@ mlirAffineDimExprGetPosition(MlirAffineExpr affineExpr);
 // Affine Symbol Expression.
 //===----------------------------------------------------------------------===//
 
+/// Checks whether the given affine expression is a symbol expression.
+MLIR_CAPI_EXPORTED bool mlirAffineExprIsASymbol(MlirAffineExpr affineExpr);
+
 /// Creates an affine symbol expression with 'position' in the context.
 MLIR_CAPI_EXPORTED MlirAffineExpr mlirAffineSymbolExprGet(MlirContext ctx,
                                                           intptr_t position);
@@ -106,6 +122,9 @@ mlirAffineSymbolExprGetPosition(MlirAffineExpr affineExpr);
 // Affine Constant Expression.
 //===----------------------------------------------------------------------===//
 
+/// Checks whether the given affine expression is a constant expression.
+MLIR_CAPI_EXPORTED bool mlirAffineExprIsAConstant(MlirAffineExpr affineExpr);
+
 /// Creates an affine constant expression with 'constant' in the context.
 MLIR_CAPI_EXPORTED MlirAffineExpr mlirAffineConstantExprGet(MlirContext ctx,
                                                             int64_t constant);
@@ -173,6 +192,9 @@ MLIR_CAPI_EXPORTED MlirAffineExpr mlirAffineCeilDivExprGet(MlirAffineExpr lhs,
 // Affine Binary Operation Expression.
 //===----------------------------------------------------------------------===//
 
+/// Checks whether the given affine expression is binary.
+MLIR_CAPI_EXPORTED bool mlirAffineExprIsABinary(MlirAffineExpr affineExpr);
+
 /** Returns the left hand side affine expression of the given affine binary
  * operation expression. */
 MLIR_CAPI_EXPORTED MlirAffineExpr

diff  --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h
index ae9d3a84a0a3..d1eda4202345 100644
--- a/mlir/include/mlir-c/Bindings/Python/Interop.h
+++ b/mlir/include/mlir-c/Bindings/Python/Interop.h
@@ -23,10 +23,12 @@
 
 #include <Python.h>
 
+#include "mlir-c/AffineExpr.h"
 #include "mlir-c/AffineMap.h"
 #include "mlir-c/IR.h"
 #include "mlir-c/Pass.h"
 
+#define MLIR_PYTHON_CAPSULE_AFFINE_EXPR "mlir.ir.AffineExpr._CAPIPtr"
 #define MLIR_PYTHON_CAPSULE_AFFINE_MAP "mlir.ir.AffineMap._CAPIPtr"
 #define MLIR_PYTHON_CAPSULE_ATTRIBUTE "mlir.ir.Attribute._CAPIPtr"
 #define MLIR_PYTHON_CAPSULE_CONTEXT "mlir.ir.Context._CAPIPtr"
@@ -72,6 +74,25 @@
 extern "C" {
 #endif
 
+/** Creates a capsule object encapsulating the raw C-API MlirAffineExpr. The
+ * returned capsule does not extend or affect ownership of any Python objects
+ * that reference the expression in any way.
+ */
+static inline PyObject *mlirPythonAffineExprToCapsule(MlirAffineExpr expr) {
+  return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(expr),
+                       MLIR_PYTHON_CAPSULE_AFFINE_EXPR, NULL);
+}
+
+/** Extracts an MlirAffineExpr from a capsule as produced by
+ * mlirPythonAffineExprToCapsule. If the capsule is not of the right type, then
+ * a null expression is returned (as checked via mlirAffineExprIsNull). In such
+ * a case, the Python APIs will have already set an error. */
+static inline MlirAffineExpr mlirPythonCapsuleToAffineExpr(PyObject *capsule) {
+  void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_AFFINE_EXPR);
+  MlirAffineExpr expr = {ptr};
+  return expr;
+}
+
 /** Creates a capsule object encapsulating the raw C-API MlirAttribute.
  * The returned capsule does not extend or affect ownership of any Python
  * objects that reference the attribute in any way.

diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 2dcee494715d..2d18a7a488e7 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -2710,6 +2710,238 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
 
 } // namespace
 
+//------------------------------------------------------------------------------
+// PyAffineExpr and subclasses.
+//------------------------------------------------------------------------------
+
+namespace {
+/// CRTP base class for Python MLIR affine expressions that subclass AffineExpr
+/// and should be castable from it. Intermediate hierarchy classes can be
+/// modeled by specifying BaseTy.
+template <typename DerivedTy, typename BaseTy = PyAffineExpr>
+class PyConcreteAffineExpr : public BaseTy {
+public:
+  // Derived classes must define statics for:
+  //   IsAFunctionTy isaFunction
+  //   const char *pyClassName
+  // and may shadow bindDerived to add class-specific Python methods.
+  using ClassTy = py::class_<DerivedTy, BaseTy>;
+  using IsAFunctionTy = bool (*)(MlirAffineExpr);
+
+  PyConcreteAffineExpr() = default;
+  PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
+      : BaseTy(std::move(contextRef), affineExpr) {}
+  PyConcreteAffineExpr(PyAffineExpr &orig)
+      : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {}
+
+  static MlirAffineExpr castFrom(PyAffineExpr &orig) {
+    if (!DerivedTy::isaFunction(orig)) {
+      auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
+      throw SetPyError(PyExc_ValueError,
+                       Twine("Cannot cast affine expression to ") +
+                           DerivedTy::pyClassName + " (from " + origRepr + ")");
+    }
+    return orig;
+  }
+
+  static void bind(py::module &m) {
+    auto cls = ClassTy(m, DerivedTy::pyClassName);
+    cls.def(py::init<PyAffineExpr &>());
+    DerivedTy::bindDerived(cls);
+  }
+
+  /// Default no-op; derived classes shadow it to add Python-side methods.
+  static void bindDerived(ClassTy &m) {}
+};
+class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant;
+  static constexpr const char *pyClassName = "AffineConstantExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+  /// Takes `int64_t`, the C API's constant type, to keep its full range.
+  static PyAffineConstantExpr get(int64_t value,
+                                  DefaultingPyMlirContext context) {
+    MlirAffineExpr affineExpr =
+        mlirAffineConstantExprGet(context->get(), value);
+    return PyAffineConstantExpr(context->getRef(), affineExpr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"),
+                 py::arg("context") = py::none());
+    c.def_property_readonly("value", [](PyAffineConstantExpr &self) {
+      return mlirAffineConstantExprGetValue(self);
+    });
+  }
+};
+
+class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim;
+  static constexpr const char *pyClassName = "AffineDimExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) {
+    MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos);
+    return PyAffineDimExpr(context->getRef(), affineExpr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", &PyAffineDimExpr::get, py::arg("position"),
+                 py::arg("context") = py::none());
+    c.def_property_readonly("position", [](PyAffineDimExpr &self) {
+      return mlirAffineDimExprGetPosition(self);
+    });
+  }
+};
+
+class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol;
+  static constexpr const char *pyClassName = "AffineSymbolExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) {
+    MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos);
+    return PyAffineSymbolExpr(context->getRef(), affineExpr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"),
+                 py::arg("context") = py::none());
+    c.def_property_readonly("position", [](PyAffineSymbolExpr &self) {
+      return mlirAffineSymbolExprGetPosition(self);
+    });
+  }
+};
+
+class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary;
+  static constexpr const char *pyClassName = "AffineBinaryExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  PyAffineExpr lhs() { // Returned as a plain AffineExpr (no downcast).
+    MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get());
+    return PyAffineExpr(getContext(), lhsExpr);
+  }
+
+  PyAffineExpr rhs() { // Likewise returned as a plain AffineExpr.
+    MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get());
+    return PyAffineExpr(getContext(), rhsExpr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs);
+    c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs);
+  }
+};
+
+class PyAffineAddExpr
+    : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd;
+  static constexpr const char *pyClassName = "AffineAddExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  static PyAffineAddExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
+    MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs);
+    return PyAffineAddExpr(lhs.getContext(), expr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", &PyAffineAddExpr::get);
+  }
+};
+
+class PyAffineMulExpr
+    : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul;
+  static constexpr const char *pyClassName = "AffineMulExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  static PyAffineMulExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
+    MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs);
+    return PyAffineMulExpr(lhs.getContext(), expr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", &PyAffineMulExpr::get);
+  }
+};
+
+class PyAffineModExpr
+    : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod;
+  static constexpr const char *pyClassName = "AffineModExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  static PyAffineModExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
+    MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs);
+    return PyAffineModExpr(lhs.getContext(), expr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", &PyAffineModExpr::get);
+  }
+};
+
+class PyAffineFloorDivExpr
+    : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv;
+  static constexpr const char *pyClassName = "AffineFloorDivExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  static PyAffineFloorDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
+    MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs);
+    return PyAffineFloorDivExpr(lhs.getContext(), expr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", &PyAffineFloorDivExpr::get);
+  }
+};
+
+class PyAffineCeilDivExpr
+    : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv;
+  static constexpr const char *pyClassName = "AffineCeilDivExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  static PyAffineCeilDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
+    MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs);
+    return PyAffineCeilDivExpr(lhs.getContext(), expr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", &PyAffineCeilDivExpr::get);
+  }
+};
+} // namespace
+
+bool PyAffineExpr::operator==(const PyAffineExpr &other) {
+  return mlirAffineExprEqual(affineExpr, other.affineExpr);
+}
+
+py::object PyAffineExpr::getCapsule() {
+  return py::reinterpret_steal<py::object>(
+      mlirPythonAffineExprToCapsule(*this));
+}
+
+PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) {
+  MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr());
+  if (mlirAffineExprIsNull(rawAffineExpr)) // Python error already set.
+    throw py::error_already_set();
+  return PyAffineExpr(
+      PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)),
+      rawAffineExpr);
+}
+
 //------------------------------------------------------------------------------
 // PyAffineMap.
 //------------------------------------------------------------------------------
@@ -3414,6 +3646,94 @@ void mlir::python::populateIRSubmodule(py::module &m) {
   PyRegionIterator::bind(m);
   PyRegionList::bind(m);
 
+  //----------------------------------------------------------------------------
+  // Mapping of PyAffineExpr and derived classes.
+  //----------------------------------------------------------------------------
+  py::class_<PyAffineExpr>(m, "AffineExpr")
+      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
+                             &PyAffineExpr::getCapsule)
+      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule)
+      .def("__add__",
+           [](PyAffineExpr &self, PyAffineExpr &other) {
+             return PyAffineAddExpr::get(self, other);
+           })
+      .def("__mul__",
+           [](PyAffineExpr &self, PyAffineExpr &other) {
+             return PyAffineMulExpr::get(self, other);
+           })
+      .def("__mod__",
+           [](PyAffineExpr &self, PyAffineExpr &other) {
+             return PyAffineModExpr::get(self, other);
+           })
+      .def("__sub__", // Built as self + (-1 * other).
+           [](PyAffineExpr &self, PyAffineExpr &other) {
+             auto negOne =
+                 PyAffineConstantExpr::get(-1, *self.getContext().get());
+             return PyAffineAddExpr::get(self,
+                                         PyAffineMulExpr::get(negOne, other));
+           })
+      .def("__eq__", [](PyAffineExpr &self,
+                        PyAffineExpr &other) { return self == other; })
+      .def("__eq__", // Fallback overload: any other operand is unequal.
+           [](PyAffineExpr &self, py::object &other) { return false; })
+      .def("__str__",
+           [](PyAffineExpr &self) {
+             PyPrintAccumulator printAccum;
+             mlirAffineExprPrint(self, printAccum.getCallback(),
+                                 printAccum.getUserData());
+             return printAccum.join();
+           })
+      .def("__repr__",
+           [](PyAffineExpr &self) {
+             PyPrintAccumulator printAccum;
+             printAccum.parts.append("AffineExpr(");
+             mlirAffineExprPrint(self, printAccum.getCallback(),
+                                 printAccum.getUserData());
+             printAccum.parts.append(")");
+             return printAccum.join();
+           })
+      .def_property_readonly(
+          "context",
+          [](PyAffineExpr &self) { return self.getContext().getObject(); })
+      .def_static(
+          "get_add", &PyAffineAddExpr::get,
+          "Gets an affine expression containing a sum of two expressions.")
+      .def_static(
+          "get_mul", &PyAffineMulExpr::get,
+          "Gets an affine expression containing a product of two expressions.")
+      .def_static("get_mod", &PyAffineModExpr::get,
+                  "Gets an affine expression containing the modulo of dividing "
+                  "one expression by another.")
+      .def_static("get_floor_div", &PyAffineFloorDivExpr::get,
+                  "Gets an affine expression containing the rounded-down "
+                  "result of dividing one expression by another.")
+      .def_static("get_ceil_div", &PyAffineCeilDivExpr::get,
+                  "Gets an affine expression containing the rounded-up result "
+                  "of dividing one expression by another.")
+      .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"),
+                  py::arg("context") = py::none(),
+                  "Gets a constant affine expression with the given value.")
+      .def_static(
+          "get_dim", &PyAffineDimExpr::get, py::arg("position"),
+          py::arg("context") = py::none(),
+          "Gets an affine expression of a dimension at the given position.")
+      .def_static(
+          "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"),
+          py::arg("context") = py::none(),
+          "Gets an affine expression of a symbol at the given position.")
+      .def(
+          "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); },
+          kDumpDocstring);
+  PyAffineConstantExpr::bind(m);
+  PyAffineDimExpr::bind(m);
+  PyAffineSymbolExpr::bind(m);
+  PyAffineBinaryExpr::bind(m);
+  PyAffineAddExpr::bind(m);
+  PyAffineMulExpr::bind(m);
+  PyAffineModExpr::bind(m);
+  PyAffineFloorDivExpr::bind(m);
+  PyAffineCeilDivExpr::bind(m);
+
   //----------------------------------------------------------------------------
   // Mapping of PyAffineMap.
   //----------------------------------------------------------------------------

diff  --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
index 3438dd8c0270..e789f536a829 100644
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -13,6 +13,7 @@
 
 #include "PybindUtils.h"
 
+#include "mlir-c/AffineExpr.h"
 #include "mlir-c/AffineMap.h"
 #include "mlir-c/IR.h"
 #include "llvm/ADT/DenseMap.h"
@@ -668,6 +669,28 @@ class PyValue {
   MlirValue value;
 };
 
+/// Wrapper around MlirAffineExpr. Affine expressions are owned by the context.
+class PyAffineExpr : public BaseContextObject {
+public:
+  PyAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
+      : BaseContextObject(std::move(contextRef)), affineExpr(affineExpr) {}
+  bool operator==(const PyAffineExpr &other);
+  operator MlirAffineExpr() const { return affineExpr; }
+  MlirAffineExpr get() const { return affineExpr; }
+
+  /// Gets a capsule wrapping the void* within the MlirAffineExpr.
+  pybind11::object getCapsule();
+
+  /// Creates a PyAffineExpr from the MlirAffineExpr wrapped by a capsule.
+  /// The expression stays owned by its MLIR context: no ownership is taken,
+  /// and a new wrapper is returned each time. Throws py::error_already_set
+  /// if the capsule does not hold an MlirAffineExpr.
+  static PyAffineExpr createFromCapsule(pybind11::object capsule);
+
+private:
+  MlirAffineExpr affineExpr;
+};
+
 class PyAffineMap : public BaseContextObject {
 public:
   PyAffineMap(PyMlirContextRef contextRef, MlirAffineMap affineMap)

diff  --git a/mlir/lib/CAPI/IR/AffineExpr.cpp b/mlir/lib/CAPI/IR/AffineExpr.cpp
index 01793192b05c..2d8bc3ce569a 100644
--- a/mlir/lib/CAPI/IR/AffineExpr.cpp
+++ b/mlir/lib/CAPI/IR/AffineExpr.cpp
@@ -21,6 +21,10 @@ MlirContext mlirAffineExprGetContext(MlirAffineExpr affineExpr) {
   return wrap(unwrap(affineExpr).getContext());
 }
 
+bool mlirAffineExprEqual(MlirAffineExpr lhs, MlirAffineExpr rhs) {
+  return unwrap(lhs) == unwrap(rhs); // Compares the unwrapped AffineExprs.
+}
+
 void mlirAffineExprPrint(MlirAffineExpr affineExpr, MlirStringCallback callback,
                          void *userData) {
   mlir::detail::CallbackOstream stream(callback, userData);
@@ -56,6 +60,10 @@ bool mlirAffineExprIsFunctionOfDim(MlirAffineExpr affineExpr,
 // Affine Dimension Expression.
 //===----------------------------------------------------------------------===//
 
+bool mlirAffineExprIsADim(MlirAffineExpr affineExpr) {
+  return unwrap(affineExpr).isa<AffineDimExpr>();
+}
+
 MlirAffineExpr mlirAffineDimExprGet(MlirContext ctx, intptr_t position) {
   return wrap(getAffineDimExpr(position, unwrap(ctx)));
 }
@@ -68,6 +76,10 @@ intptr_t mlirAffineDimExprGetPosition(MlirAffineExpr affineExpr) {
 // Affine Symbol Expression.
 //===----------------------------------------------------------------------===//
 
+bool mlirAffineExprIsASymbol(MlirAffineExpr affineExpr) {
+  return unwrap(affineExpr).isa<AffineSymbolExpr>();
+}
+
 MlirAffineExpr mlirAffineSymbolExprGet(MlirContext ctx, intptr_t position) {
   return wrap(getAffineSymbolExpr(position, unwrap(ctx)));
 }
@@ -80,6 +92,10 @@ intptr_t mlirAffineSymbolExprGetPosition(MlirAffineExpr affineExpr) {
 // Affine Constant Expression.
 //===----------------------------------------------------------------------===//
 
+bool mlirAffineExprIsAConstant(MlirAffineExpr affineExpr) {
+  return unwrap(affineExpr).isa<AffineConstantExpr>();
+}
+
 MlirAffineExpr mlirAffineConstantExprGet(MlirContext ctx, int64_t constant) {
   return wrap(getAffineConstantExpr(constant, unwrap(ctx)));
 }
@@ -159,6 +175,10 @@ MlirAffineExpr mlirAffineCeilDivExprGet(MlirAffineExpr lhs,
 // Affine Binary Operation Expression.
 //===----------------------------------------------------------------------===//
 
+bool mlirAffineExprIsABinary(MlirAffineExpr affineExpr) {
+  return unwrap(affineExpr).isa<AffineBinaryOpExpr>();
+}
+
 MlirAffineExpr mlirAffineBinaryOpExprGetLHS(MlirAffineExpr affineExpr) {
   return wrap(unwrap(affineExpr).cast<AffineBinaryOpExpr>().getLHS());
 }

diff  --git a/mlir/test/Bindings/Python/ir_affine_expr.py b/mlir/test/Bindings/Python/ir_affine_expr.py
new file mode 100644
index 000000000000..eb58579448ca
--- /dev/null
+++ b/mlir/test/Bindings/Python/ir_affine_expr.py
@@ -0,0 +1,275 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import gc
+from mlir.ir import *
+
+def run(f):
+  print("\nTEST:", f.__name__)
+  f()
+  gc.collect()  # Collect dropped wrappers before checking for live contexts.
+  assert Context._get_live_count() == 0
+
+
+# CHECK-LABEL: TEST: testAffineExprCapsule
+def testAffineExprCapsule():
+  with Context() as ctx:
+    affine_expr = AffineExpr.get_constant(42)
+
+  affine_expr_capsule = affine_expr._CAPIPtr
+  # CHECK: capsule object
+  # CHECK: mlir.ir.AffineExpr._CAPIPtr
+  print(affine_expr_capsule)
+
+  affine_expr_2 = AffineExpr._CAPICreate(affine_expr_capsule)
+  assert affine_expr == affine_expr_2
+  assert affine_expr_2.context == ctx
+
+run(testAffineExprCapsule)
+
+
+# CHECK-LABEL: TEST: testAffineExprEq
+def testAffineExprEq():
+  with Context():
+    a1 = AffineExpr.get_constant(42)
+    a2 = AffineExpr.get_constant(42)
+    a3 = AffineExpr.get_constant(43)
+    # CHECK: True
+    print(a1 == a1)
+    # CHECK: True
+    print(a1 == a2)
+    # CHECK: False
+    print(a1 == a3)
+    # CHECK: False
+    print(a1 == None)
+    # CHECK: False
+    print(a1 == "foo")
+
+run(testAffineExprEq)
+
+
+# CHECK-LABEL: TEST: testAffineExprContext
+def testAffineExprContext():
+  with Context():
+    a1 = AffineExpr.get_constant(42)
+  with Context():
+    a2 = AffineExpr.get_constant(42)
+
+  # CHECK: False
+  print(a1 == a2)
+
+run(testAffineExprContext)
+
+
+# CHECK-LABEL: TEST: testAffineExprConstant
+def testAffineExprConstant():
+  with Context():
+    a1 = AffineExpr.get_constant(42)
+    # CHECK: 42
+    print(a1.value)
+    # CHECK: 42
+    print(a1)
+
+    a2 = AffineConstantExpr.get(42)
+    # CHECK: 42
+    print(a2.value)
+    # CHECK: 42
+    print(a2)
+
+    assert a1 == a2
+
+run(testAffineExprConstant)
+
+
+# CHECK-LABEL: TEST: testAffineExprDim
+def testAffineExprDim():
+  with Context():
+    d1 = AffineExpr.get_dim(1)
+    d11 = AffineDimExpr.get(1)
+    d2 = AffineDimExpr.get(2)
+
+    # CHECK: 1
+    print(d1.position)
+    # CHECK: d1
+    print(d1)
+
+    # CHECK: 2
+    print(d2.position)
+    # CHECK: d2
+    print(d2)
+
+    assert d1 == d11
+    assert d1 != d2
+
+run(testAffineExprDim)
+
+
+# CHECK-LABEL: TEST: testAffineExprSymbol
+def testAffineExprSymbol():
+  with Context():
+    s1 = AffineExpr.get_symbol(1)
+    s11 = AffineSymbolExpr.get(1)
+    s2 = AffineSymbolExpr.get(2)
+
+    # CHECK: 1
+    print(s1.position)
+    # CHECK: s1
+    print(s1)
+
+    # CHECK: 2
+    print(s2.position)
+    # CHECK: s2
+    print(s2)
+
+    assert s1 == s11
+    assert s1 != s2
+
+run(testAffineExprSymbol)
+
+
+# CHECK-LABEL: TEST: testAffineAddExpr
+def testAffineAddExpr():
+  with Context():
+    d1 = AffineDimExpr.get(1)
+    d2 = AffineDimExpr.get(2)
+    d12 = AffineExpr.get_add(d1, d2)
+    # CHECK: d1 + d2
+    print(d12)
+
+    d12op = d1 + d2
+    # CHECK: d1 + d2
+    print(d12op)
+
+    assert d12 == d12op
+    assert d12.lhs == d1
+    assert d12.rhs == d2
+
+run(testAffineAddExpr)
+
+
+# CHECK-LABEL: TEST: testAffineMulExpr
+def testAffineMulExpr():
+  with Context():
+    d1 = AffineDimExpr.get(1)
+    c2 = AffineConstantExpr.get(2)
+    expr = AffineExpr.get_mul(d1, c2)
+    # CHECK: d1 * 2
+    print(expr)
+
+    # CHECK: d1 * 2
+    op = d1 * c2
+    print(op)
+
+    assert expr == op
+    assert expr.lhs == d1
+    assert expr.rhs == c2
+
+run(testAffineMulExpr)
+
+
+# CHECK-LABEL: TEST: testAffineModExpr
+def testAffineModExpr():
+  with Context():
+    d1 = AffineDimExpr.get(1)
+    c2 = AffineConstantExpr.get(2)
+    expr = AffineExpr.get_mod(d1, c2)
+    # CHECK: d1 mod 2
+    print(expr)
+
+    # CHECK: d1 mod 2
+    op = d1 % c2
+    print(op)
+
+    assert expr == op
+    assert expr.lhs == d1
+    assert expr.rhs == c2
+
+run(testAffineModExpr)
+
+
+# CHECK-LABEL: TEST: testAffineFloorDivExpr
+def testAffineFloorDivExpr():
+  with Context():
+    d1 = AffineDimExpr.get(1)
+    c2 = AffineConstantExpr.get(2)
+    expr = AffineExpr.get_floor_div(d1, c2)
+    # CHECK: d1 floordiv 2
+    print(expr)
+
+    assert expr.lhs == d1
+    assert expr.rhs == c2
+
+run(testAffineFloorDivExpr)
+
+
+# CHECK-LABEL: TEST: testAffineCeilDivExpr
+def testAffineCeilDivExpr():
+  with Context():
+    d1 = AffineDimExpr.get(1)
+    c2 = AffineConstantExpr.get(2)
+    expr = AffineExpr.get_ceil_div(d1, c2)
+    # CHECK: d1 ceildiv 2
+    print(expr)
+
+    assert expr.lhs == d1
+    assert expr.rhs == c2
+
+run(testAffineCeilDivExpr)
+
+
+# CHECK-LABEL: TEST: testAffineExprSub
+def testAffineExprSub():
+  with Context():
+    d1 = AffineDimExpr.get(1)
+    d2 = AffineDimExpr.get(2)
+    expr = d1 - d2
+    # CHECK: d1 - d2
+    print(expr)
+
+    assert expr.lhs == d1
+    rhs = AffineMulExpr(expr.rhs)  # rhs is a plain AffineExpr; cast to inspect.
+    # CHECK: d2
+    print(rhs.lhs)
+    # CHECK: -1
+    print(rhs.rhs)
+
+run(testAffineExprSub)
+
+# CHECK-LABEL: TEST: testClassHierarchy
+def testClassHierarchy():
+  with Context():
+    d1 = AffineDimExpr.get(1)
+    c2 = AffineConstantExpr.get(2)
+    add = AffineAddExpr.get(d1, c2)
+    mul = AffineMulExpr.get(d1, c2)
+    mod = AffineModExpr.get(d1, c2)
+    floor_div = AffineFloorDivExpr.get(d1, c2)
+    ceil_div = AffineCeilDivExpr.get(d1, c2)
+
+    # CHECK: False
+    print(isinstance(d1, AffineBinaryExpr))
+    # CHECK: False
+    print(isinstance(c2, AffineBinaryExpr))
+    # CHECK: True
+    print(isinstance(add, AffineBinaryExpr))
+    # CHECK: True
+    print(isinstance(mul, AffineBinaryExpr))
+    # CHECK: True
+    print(isinstance(mod, AffineBinaryExpr))
+    # CHECK: True
+    print(isinstance(floor_div, AffineBinaryExpr))
+    # CHECK: True
+    print(isinstance(ceil_div, AffineBinaryExpr))
+
+    try:
+      AffineBinaryExpr(d1)
+    except ValueError as e:
+      # CHECK: Cannot cast affine expression to AffineBinaryExpr
+      print(e)
+
+    try:
+      AffineBinaryExpr(c2)
+    except ValueError as e:
+      # CHECK: Cannot cast affine expression to AffineBinaryExpr
+      print(e)
+
+run(testClassHierarchy)

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 434f272de059..550f799440f9 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -1251,6 +1251,27 @@ int printAffineExpr(MlirContext ctx) {
   if (!mlirAffineExprIsACeilDiv(affineCeilDivExpr))
     return 13;
 
+  if (!mlirAffineExprIsABinary(affineAddExpr))
+    return 14;
+
+  // Test the other 'IsA' methods on affine expressions.
+  if (!mlirAffineExprIsAConstant(affineConstantExpr))
+    return 15;
+
+  if (!mlirAffineExprIsADim(affineDimExpr))
+    return 16;
+
+  if (!mlirAffineExprIsASymbol(affineSymbolExpr))
+    return 17;
+
+  // Test equality and nullity (assumes affineDimExpr is d5 -- confirm).
+  MlirAffineExpr otherDimExpr = mlirAffineDimExprGet(ctx, 5);
+  if (!mlirAffineExprEqual(affineDimExpr, otherDimExpr))
+    return 18;
+
+  if (mlirAffineExprIsNull(affineDimExpr))
+    return 19;
+
   return 0;
 }
 


        


More information about the Mlir-commits mailing list