[Mlir-commits] [mlir] 74628c4 - [mlir] Add Python bindings for AffineExpr
Alex Zinenko
llvmlistbot at llvm.org
Mon Jan 11 10:57:25 PST 2021
Author: Alex Zinenko
Date: 2021-01-11T19:57:13+01:00
New Revision: 74628c43053b482f35f0f1e6b4eac743fbe425e5
URL: https://github.com/llvm/llvm-project/commit/74628c43053b482f35f0f1e6b4eac743fbe425e5
DIFF: https://github.com/llvm/llvm-project/commit/74628c43053b482f35f0f1e6b4eac743fbe425e5.diff
LOG: [mlir] Add Python bindings for AffineExpr
This adds the Python bindings for AffineExpr and a couple of utility functions
to the C API. AffineExpr is a top-level context-owned object and is modeled
similarly to attributes and types. It is required, e.g., to build layout maps
of the built-in memref type.
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D94225
Added:
mlir/test/Bindings/Python/ir_affine_expr.py
Modified:
mlir/include/mlir-c/AffineExpr.h
mlir/include/mlir-c/Bindings/Python/Interop.h
mlir/lib/Bindings/Python/IRModules.cpp
mlir/lib/Bindings/Python/IRModules.h
mlir/lib/CAPI/IR/AffineExpr.cpp
mlir/test/CAPI/ir.c
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/AffineExpr.h b/mlir/include/mlir-c/AffineExpr.h
index 93b8e832b44f..2eb8ae03e03d 100644
--- a/mlir/include/mlir-c/AffineExpr.h
+++ b/mlir/include/mlir-c/AffineExpr.h
@@ -45,6 +45,16 @@ DEFINE_C_API_STRUCT(MlirAffineExpr, const void);
MLIR_CAPI_EXPORTED MlirContext
mlirAffineExprGetContext(MlirAffineExpr affineExpr);
+/// Returns `true` if the two affine expressions are equal.
+MLIR_CAPI_EXPORTED bool mlirAffineExprEqual(MlirAffineExpr lhs,
+                                            MlirAffineExpr rhs);
+
+/// Returns `true` if the given affine expression is a null expression. Note
+/// constant zero is not a null expression.
+///
+/// Defined `static inline` so that every translation unit including this
+/// header gets its own definition: a plain C99 `inline` definition in a header
+/// provides no external definition, which leads to undefined-symbol link
+/// errors whenever the compiler decides not to inline a call.
+static inline bool mlirAffineExprIsNull(MlirAffineExpr affineExpr) {
+  return affineExpr.ptr == NULL;
+}
+
/** Prints an affine expression by sending chunks of the string representation
* and forwarding `userData to `callback`. Note that the callback may be called
* several times with consecutive chunks of the string. */
@@ -82,6 +92,9 @@ MLIR_CAPI_EXPORTED bool mlirAffineExprIsFunctionOfDim(MlirAffineExpr affineExpr,
// Affine Dimension Expression.
//===----------------------------------------------------------------------===//
+/// Checks whether the given affine expression is a dimension expression, i.e.
+/// one of the kind created by mlirAffineDimExprGet.
+MLIR_CAPI_EXPORTED bool mlirAffineExprIsADim(MlirAffineExpr affineExpr);
+
/// Creates an affine dimension expression with 'position' in the context.
MLIR_CAPI_EXPORTED MlirAffineExpr mlirAffineDimExprGet(MlirContext ctx,
intptr_t position);
@@ -94,6 +107,9 @@ mlirAffineDimExprGetPosition(MlirAffineExpr affineExpr);
// Affine Symbol Expression.
//===----------------------------------------------------------------------===//
+/// Checks whether the given affine expression is a symbol expression, i.e. one
+/// of the kind created by mlirAffineSymbolExprGet.
+MLIR_CAPI_EXPORTED bool mlirAffineExprIsASymbol(MlirAffineExpr affineExpr);
+
/// Creates an affine symbol expression with 'position' in the context.
MLIR_CAPI_EXPORTED MlirAffineExpr mlirAffineSymbolExprGet(MlirContext ctx,
intptr_t position);
@@ -106,6 +122,9 @@ mlirAffineSymbolExprGetPosition(MlirAffineExpr affineExpr);
// Affine Constant Expression.
//===----------------------------------------------------------------------===//
+/// Checks whether the given affine expression is a constant expression, i.e.
+/// one of the kind created by mlirAffineConstantExprGet.
+MLIR_CAPI_EXPORTED bool mlirAffineExprIsAConstant(MlirAffineExpr affineExpr);
+
/// Creates an affine constant expression with 'constant' in the context.
MLIR_CAPI_EXPORTED MlirAffineExpr mlirAffineConstantExprGet(MlirContext ctx,
int64_t constant);
@@ -173,6 +192,9 @@ MLIR_CAPI_EXPORTED MlirAffineExpr mlirAffineCeilDivExprGet(MlirAffineExpr lhs,
// Affine Binary Operation Expression.
//===----------------------------------------------------------------------===//
+/// Checks whether the given affine expression is a binary operation
+/// expression: an add, mul, mod, floordiv or ceildiv expression.
+MLIR_CAPI_EXPORTED bool mlirAffineExprIsABinary(MlirAffineExpr affineExpr);
+
/** Returns the left hand side affine expression of the given affine binary
* operation expression. */
MLIR_CAPI_EXPORTED MlirAffineExpr
diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h
index ae9d3a84a0a3..d1eda4202345 100644
--- a/mlir/include/mlir-c/Bindings/Python/Interop.h
+++ b/mlir/include/mlir-c/Bindings/Python/Interop.h
@@ -23,10 +23,12 @@
#include <Python.h>
+#include "mlir-c/AffineExpr.h"
#include "mlir-c/AffineMap.h"
#include "mlir-c/IR.h"
#include "mlir-c/Pass.h"
+/** Name attached to PyCapsules that wrap an MlirAffineExpr. */
+#define MLIR_PYTHON_CAPSULE_AFFINE_EXPR "mlir.ir.AffineExpr._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_AFFINE_MAP "mlir.ir.AffineMap._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_ATTRIBUTE "mlir.ir.Attribute._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_CONTEXT "mlir.ir.Context._CAPIPtr"
@@ -72,6 +74,25 @@
extern "C" {
#endif
+/** Wraps the raw C-API MlirAffineExpr pointer in a new capsule tagged with
+ * MLIR_PYTHON_CAPSULE_AFFINE_EXPR. The capsule has no destructor and neither
+ * extends nor affects the ownership of any Python object that references the
+ * expression.
+ */
+static inline PyObject *mlirPythonAffineExprToCapsule(MlirAffineExpr expr) {
+  void *rawPtr = MLIR_PYTHON_GET_WRAPPED_POINTER(expr);
+  return PyCapsule_New(rawPtr, MLIR_PYTHON_CAPSULE_AFFINE_EXPR, NULL);
+}
+
+/** Recovers the MlirAffineExpr stored in a capsule produced by
+ * mlirPythonAffineExprToCapsule. When the capsule is not tagged with
+ * MLIR_PYTHON_CAPSULE_AFFINE_EXPR, the result is a null expression (detect it
+ * with mlirAffineExprIsNull) and the Python APIs will have already set an
+ * error. */
+static inline MlirAffineExpr mlirPythonCapsuleToAffineExpr(PyObject *capsule) {
+  MlirAffineExpr expr;
+  expr.ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_AFFINE_EXPR);
+  return expr;
+}
+
/** Creates a capsule object encapsulating the raw C-API MlirAttribute.
* The returned capsule does not extend or affect ownership of any Python
* objects that reference the attribute in any way.
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 2dcee494715d..2d18a7a488e7 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -2710,6 +2710,238 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
} // namespace
+//------------------------------------------------------------------------------
+// PyAffineExpr and subclasses.
+//------------------------------------------------------------------------------
+
+namespace {
+/// CRTP base class for Python MLIR affine expressions that subclass AffineExpr
+/// and should be castable from it. Intermediate hierarchy classes can be
+/// modeled by specifying BaseTy.
+///
+/// Calling the Python constructor of a derived class with a generic
+/// `AffineExpr` (e.g. `AffineMulExpr(expr)`) goes through `castFrom`, which
+/// raises `ValueError` when the expression is not of the derived kind.
+template <typename DerivedTy, typename BaseTy = PyAffineExpr>
+class PyConcreteAffineExpr : public BaseTy {
+public:
+  // Derived classes must define statics for:
+  //   IsAFunctionTy isaFunction
+  //   const char *pyClassName
+  // and may redefine bindDerived.
+  using ClassTy = py::class_<DerivedTy, BaseTy>;
+  using IsAFunctionTy = bool (*)(MlirAffineExpr);
+
+  PyConcreteAffineExpr() = default;
+  PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
+      : BaseTy(std::move(contextRef), affineExpr) {}
+  PyConcreteAffineExpr(PyAffineExpr &orig)
+      : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {}
+
+  /// Returns the expression wrapped by `orig` if DerivedTy::isaFunction
+  /// accepts it; otherwise throws a Python ValueError naming the target class
+  /// and the repr of `orig`.
+  static MlirAffineExpr castFrom(PyAffineExpr &orig) {
+    if (!DerivedTy::isaFunction(orig)) {
+      auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
+      throw SetPyError(PyExc_ValueError,
+                       Twine("Cannot cast affine expression to ") +
+                           DerivedTy::pyClassName + " (from " + origRepr + ")");
+    }
+    return orig;
+  }
+
+  /// Registers the Python class named DerivedTy::pyClassName with a
+  /// constructor casting from AffineExpr, then the class-specific members.
+  static void bind(py::module &m) {
+    auto cls = ClassTy(m, DerivedTy::pyClassName);
+    cls.def(py::init<PyAffineExpr &>());
+    DerivedTy::bindDerived(cls);
+  }
+
+  /// Implemented by derived classes to add methods to the Python subclass.
+  static void bindDerived(ClassTy &c) {}
+};
+
+class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant;
+  static constexpr const char *pyClassName = "AffineConstantExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  /// Returns the constant expression `value` in `context`. The value is taken
+  /// as int64_t to match mlirAffineConstantExprGet/mlirAffineConstantExprGetValue
+  /// (intptr_t would truncate constants on 32-bit hosts).
+  static PyAffineConstantExpr get(int64_t value,
+                                  DefaultingPyMlirContext context) {
+    MlirAffineExpr affineExpr = mlirAffineConstantExprGet(context->get(), value);
+    return PyAffineConstantExpr(context->getRef(), affineExpr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"),
+                 py::arg("context") = py::none());
+    c.def_property_readonly("value", [](PyAffineConstantExpr &self) {
+      return mlirAffineConstantExprGetValue(self);
+    });
+  }
+};
+
+class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim;
+  static constexpr const char *pyClassName = "AffineDimExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  /// Returns the dimension expression at position `pos` in `context`.
+  static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) {
+    MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos);
+    return PyAffineDimExpr(context->getRef(), affineExpr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", &PyAffineDimExpr::get, py::arg("position"),
+                 py::arg("context") = py::none());
+    c.def_property_readonly("position", [](PyAffineDimExpr &self) {
+      return mlirAffineDimExprGetPosition(self);
+    });
+  }
+};
+
+class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol;
+  static constexpr const char *pyClassName = "AffineSymbolExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  /// Returns the symbol expression at position `pos` in `context`.
+  static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) {
+    MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos);
+    return PyAffineSymbolExpr(context->getRef(), affineExpr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"),
+                 py::arg("context") = py::none());
+    c.def_property_readonly("position", [](PyAffineSymbolExpr &self) {
+      return mlirAffineSymbolExprGetPosition(self);
+    });
+  }
+};
+
+/// Throws a Python ValueError unless `lhs` and `rhs` belong to the same MLIR
+/// context. The binary-expression factories below build the Python wrapper of
+/// their result from `lhs.getContext()` only, so an operand taken from another
+/// context would be referenced by a result that does not keep that other
+/// context alive.
+void checkSameContext(const PyAffineExpr &lhs, const PyAffineExpr &rhs) {
+  if (!mlirContextEqual(mlirAffineExprGetContext(lhs),
+                        mlirAffineExprGetContext(rhs)))
+    throw SetPyError(PyExc_ValueError,
+                     "Cannot combine affine expressions from different "
+                     "contexts");
+}
+
+/// Common base of the binary operation expressions (add, mul, mod, floordiv,
+/// ceildiv); exposes the two operands as generic AffineExpr objects.
+class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary;
+  static constexpr const char *pyClassName = "AffineBinaryExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  /// Returns the left-hand operand as a generic AffineExpr.
+  PyAffineExpr lhs() {
+    MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get());
+    return PyAffineExpr(getContext(), lhsExpr);
+  }
+
+  /// Returns the right-hand operand as a generic AffineExpr.
+  PyAffineExpr rhs() {
+    MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get());
+    return PyAffineExpr(getContext(), rhsExpr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs);
+    c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs);
+  }
+};
+
+class PyAffineAddExpr
+    : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd;
+  static constexpr const char *pyClassName = "AffineAddExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  /// Returns `lhs + rhs`; throws ValueError if the operands come from
+  /// different contexts.
+  static PyAffineAddExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
+    checkSameContext(lhs, rhs);
+    MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs);
+    return PyAffineAddExpr(lhs.getContext(), expr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", &PyAffineAddExpr::get, py::arg("lhs"), py::arg("rhs"));
+  }
+};
+
+class PyAffineMulExpr
+    : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul;
+  static constexpr const char *pyClassName = "AffineMulExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  /// Returns `lhs * rhs`; throws ValueError if the operands come from
+  /// different contexts.
+  static PyAffineMulExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
+    checkSameContext(lhs, rhs);
+    MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs);
+    return PyAffineMulExpr(lhs.getContext(), expr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", &PyAffineMulExpr::get, py::arg("lhs"), py::arg("rhs"));
+  }
+};
+
+class PyAffineModExpr
+    : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod;
+  static constexpr const char *pyClassName = "AffineModExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  /// Returns `lhs mod rhs`; throws ValueError if the operands come from
+  /// different contexts.
+  static PyAffineModExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
+    checkSameContext(lhs, rhs);
+    MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs);
+    return PyAffineModExpr(lhs.getContext(), expr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", &PyAffineModExpr::get, py::arg("lhs"), py::arg("rhs"));
+  }
+};
+
+class PyAffineFloorDivExpr
+    : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv;
+  static constexpr const char *pyClassName = "AffineFloorDivExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  /// Returns `lhs floordiv rhs`; throws ValueError if the operands come from
+  /// different contexts.
+  static PyAffineFloorDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
+    checkSameContext(lhs, rhs);
+    MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs);
+    return PyAffineFloorDivExpr(lhs.getContext(), expr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", &PyAffineFloorDivExpr::get, py::arg("lhs"),
+                 py::arg("rhs"));
+  }
+};
+
+class PyAffineCeilDivExpr
+    : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv;
+  static constexpr const char *pyClassName = "AffineCeilDivExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  /// Returns `lhs ceildiv rhs`; throws ValueError if the operands come from
+  /// different contexts.
+  static PyAffineCeilDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
+    checkSameContext(lhs, rhs);
+    MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs);
+    return PyAffineCeilDivExpr(lhs.getContext(), expr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", &PyAffineCeilDivExpr::get, py::arg("lhs"),
+                 py::arg("rhs"));
+  }
+};
+} // namespace
+
+/// Two wrappers are equal when the C API reports their expressions as equal.
+bool PyAffineExpr::operator==(const PyAffineExpr &other) {
+  MlirAffineExpr mine = get();
+  MlirAffineExpr theirs = other.get();
+  return mlirAffineExprEqual(mine, theirs);
+}
+
+py::object PyAffineExpr::getCapsule() {
+  // PyCapsule_New (inside the helper) returns a new reference, which the
+  // resulting py::object takes over.
+  PyObject *capsule = mlirPythonAffineExprToCapsule(get());
+  return py::reinterpret_steal<py::object>(capsule);
+}
+
+PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) {
+  MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr());
+  // A null expression means the capsule had the wrong name; the Python error
+  // has already been set by PyCapsule_GetPointer.
+  if (mlirAffineExprIsNull(rawAffineExpr))
+    throw py::error_already_set();
+  MlirContext owningContext = mlirAffineExprGetContext(rawAffineExpr);
+  auto contextRef = PyMlirContext::forContext(owningContext);
+  return PyAffineExpr(std::move(contextRef), rawAffineExpr);
+}
+
//------------------------------------------------------------------------------
// PyAffineMap.
//------------------------------------------------------------------------------
@@ -3414,6 +3646,94 @@ void mlir::python::populateIRSubmodule(py::module &m) {
PyRegionIterator::bind(m);
PyRegionList::bind(m);
+ //----------------------------------------------------------------------------
+ // Mapping of PyAffineExpr and derived classes.
+ //----------------------------------------------------------------------------
+ py::class_<PyAffineExpr>(m, "AffineExpr")
+ .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
+ &PyAffineExpr::getCapsule)
+ .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule)
+ .def("__add__",
+ [](PyAffineExpr &self, PyAffineExpr &other) {
+ return PyAffineAddExpr::get(self, other);
+ })
+ .def("__mul__",
+ [](PyAffineExpr &self, PyAffineExpr &other) {
+ return PyAffineMulExpr::get(self, other);
+ })
+ .def("__mod__",
+ [](PyAffineExpr &self, PyAffineExpr &other) {
+ return PyAffineModExpr::get(self, other);
+ })
+ .def("__sub__",
+ [](PyAffineExpr &self, PyAffineExpr &other) {
+ auto negOne =
+ PyAffineConstantExpr::get(-1, *self.getContext().get());
+ return PyAffineAddExpr::get(self,
+ PyAffineMulExpr::get(negOne, other));
+ })
+ .def("__eq__", [](PyAffineExpr &self,
+ PyAffineExpr &other) { return self == other; })
+ .def("__eq__",
+ [](PyAffineExpr &self, py::object &other) { return false; })
+ .def("__str__",
+ [](PyAffineExpr &self) {
+ PyPrintAccumulator printAccum;
+ mlirAffineExprPrint(self, printAccum.getCallback(),
+ printAccum.getUserData());
+ return printAccum.join();
+ })
+ .def("__repr__",
+ [](PyAffineExpr &self) {
+ PyPrintAccumulator printAccum;
+ printAccum.parts.append("AffineExpr(");
+ mlirAffineExprPrint(self, printAccum.getCallback(),
+ printAccum.getUserData());
+ printAccum.parts.append(")");
+ return printAccum.join();
+ })
+ .def_property_readonly(
+ "context",
+ [](PyAffineExpr &self) { return self.getContext().getObject(); })
+ .def_static(
+ "get_add", &PyAffineAddExpr::get,
+ "Gets an affine expression containing a sum of two expressions.")
+ .def_static(
+ "get_mul", &PyAffineMulExpr::get,
+ "Gets an affine expression containing a product of two expressions.")
+ .def_static("get_mod", &PyAffineModExpr::get,
+ "Gets an affine expression containing the modulo of dividing "
+ "one expression by another.")
+ .def_static("get_floor_div", &PyAffineFloorDivExpr::get,
+ "Gets an affine expression containing the rounded-down "
+ "result of dividing one expression by another.")
+ .def_static("get_ceil_div", &PyAffineCeilDivExpr::get,
+ "Gets an affine expression containing the rounded-up result "
+ "of dividing one expression by another.")
+ .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"),
+ py::arg("context") = py::none(),
+ "Gets a constant affine expression with the given value.")
+ .def_static(
+ "get_dim", &PyAffineDimExpr::get, py::arg("position"),
+ py::arg("context") = py::none(),
+ "Gets an affine expression of a dimension at the given position.")
+ .def_static(
+ "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"),
+ py::arg("context") = py::none(),
+ "Gets an affine expression of a symbol at the given position.")
+ .def(
+ "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); },
+ kDumpDocstring);
+ PyAffineConstantExpr::bind(m);
+ PyAffineDimExpr::bind(m);
+ PyAffineSymbolExpr::bind(m);
+ PyAffineBinaryExpr::bind(m);
+ PyAffineAddExpr::bind(m);
+ PyAffineMulExpr::bind(m);
+ PyAffineModExpr::bind(m);
+ PyAffineFloorDivExpr::bind(m);
+ PyAffineCeilDivExpr::bind(m);
+
//----------------------------------------------------------------------------
// Mapping of PyAffineMap.
//----------------------------------------------------------------------------
diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
index 3438dd8c0270..e789f536a829 100644
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -13,6 +13,7 @@
#include "PybindUtils.h"
+#include "mlir-c/AffineExpr.h"
#include "mlir-c/AffineMap.h"
#include "mlir-c/IR.h"
#include "llvm/ADT/DenseMap.h"
@@ -668,6 +669,34 @@ class PyValue {
MlirValue value;
};
+/// Wrapper around MlirAffineExpr. Affine expressions are owned by the context;
+/// the wrapper holds a reference to the Python object of that context.
+class PyAffineExpr : public BaseContextObject {
+public:
+  PyAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
+      : BaseContextObject(std::move(contextRef)), affineExpr(affineExpr) {}
+
+  /// Returns true if both wrappers hold affine expressions that
+  /// mlirAffineExprEqual reports as equal.
+  bool operator==(const PyAffineExpr &other);
+
+  /// Access to the wrapped C API handle, implicitly or via get().
+  operator MlirAffineExpr() const { return affineExpr; }
+  MlirAffineExpr get() const { return affineExpr; }
+
+  /// Gets a capsule wrapping the void* within the MlirAffineExpr.
+  pybind11::object getCapsule();
+
+  /// Creates a PyAffineExpr from the MlirAffineExpr wrapped by a capsule.
+  /// Ownership does not change: the expression stays owned by its context and
+  /// the new wrapper references the Python object for that context. Throws
+  /// pybind11::error_already_set if the capsule does not wrap an
+  /// MlirAffineExpr.
+  static PyAffineExpr createFromCapsule(pybind11::object capsule);
+
+private:
+  MlirAffineExpr affineExpr;
+};
+
class PyAffineMap : public BaseContextObject {
public:
PyAffineMap(PyMlirContextRef contextRef, MlirAffineMap affineMap)
diff --git a/mlir/lib/CAPI/IR/AffineExpr.cpp b/mlir/lib/CAPI/IR/AffineExpr.cpp
index 01793192b05c..2d8bc3ce569a 100644
--- a/mlir/lib/CAPI/IR/AffineExpr.cpp
+++ b/mlir/lib/CAPI/IR/AffineExpr.cpp
@@ -21,6 +21,10 @@ MlirContext mlirAffineExprGetContext(MlirAffineExpr affineExpr) {
return wrap(unwrap(affineExpr).getContext());
}
+bool mlirAffineExprEqual(MlirAffineExpr lhs, MlirAffineExpr rhs) {
+  // Delegate to the C++ AffineExpr equality operator.
+  AffineExpr lhsExpr = unwrap(lhs);
+  AffineExpr rhsExpr = unwrap(rhs);
+  return lhsExpr == rhsExpr;
+}
+
void mlirAffineExprPrint(MlirAffineExpr affineExpr, MlirStringCallback callback,
void *userData) {
mlir::detail::CallbackOstream stream(callback, userData);
@@ -56,6 +60,10 @@ bool mlirAffineExprIsFunctionOfDim(MlirAffineExpr affineExpr,
// Affine Dimension Expression.
//===----------------------------------------------------------------------===//
+bool mlirAffineExprIsADim(MlirAffineExpr affineExpr) {
+  AffineExpr expr = unwrap(affineExpr);
+  return expr.isa<AffineDimExpr>();
+}
+
MlirAffineExpr mlirAffineDimExprGet(MlirContext ctx, intptr_t position) {
return wrap(getAffineDimExpr(position, unwrap(ctx)));
}
@@ -68,6 +76,10 @@ intptr_t mlirAffineDimExprGetPosition(MlirAffineExpr affineExpr) {
// Affine Symbol Expression.
//===----------------------------------------------------------------------===//
+bool mlirAffineExprIsASymbol(MlirAffineExpr affineExpr) {
+  AffineExpr expr = unwrap(affineExpr);
+  return expr.isa<AffineSymbolExpr>();
+}
+
MlirAffineExpr mlirAffineSymbolExprGet(MlirContext ctx, intptr_t position) {
return wrap(getAffineSymbolExpr(position, unwrap(ctx)));
}
@@ -80,6 +92,10 @@ intptr_t mlirAffineSymbolExprGetPosition(MlirAffineExpr affineExpr) {
// Affine Constant Expression.
//===----------------------------------------------------------------------===//
+bool mlirAffineExprIsAConstant(MlirAffineExpr affineExpr) {
+  AffineExpr expr = unwrap(affineExpr);
+  return expr.isa<AffineConstantExpr>();
+}
+
MlirAffineExpr mlirAffineConstantExprGet(MlirContext ctx, int64_t constant) {
return wrap(getAffineConstantExpr(constant, unwrap(ctx)));
}
@@ -159,6 +175,10 @@ MlirAffineExpr mlirAffineCeilDivExprGet(MlirAffineExpr lhs,
// Affine Binary Operation Expression.
//===----------------------------------------------------------------------===//
+bool mlirAffineExprIsABinary(MlirAffineExpr affineExpr) {
+  AffineExpr expr = unwrap(affineExpr);
+  return expr.isa<AffineBinaryOpExpr>();
+}
+
MlirAffineExpr mlirAffineBinaryOpExprGetLHS(MlirAffineExpr affineExpr) {
return wrap(unwrap(affineExpr).cast<AffineBinaryOpExpr>().getLHS());
}
diff --git a/mlir/test/Bindings/Python/ir_affine_expr.py b/mlir/test/Bindings/Python/ir_affine_expr.py
new file mode 100644
index 000000000000..eb58579448ca
--- /dev/null
+++ b/mlir/test/Bindings/Python/ir_affine_expr.py
@@ -0,0 +1,275 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import gc
+from mlir.ir import *
+
+def run(f):
+  """Prints a header for test `f`, runs it, and asserts no Context leaked."""
+  test_name = f.__name__
+  print("\nTEST:", test_name)
+  f()
+  # Collect garbage first so contexts dropped by the test are released before
+  # counting the live ones.
+  gc.collect()
+  live_contexts = Context._get_live_count()
+  assert live_contexts == 0
+ assert Context._get_live_count() == 0
+
+
+# CHECK-LABEL: TEST: testAffineExprCapsule
+# Round-trips an affine expression through its C API capsule and verifies the
+# recreated wrapper compares equal and reports the same context.
+def testAffineExprCapsule():
+ with Context() as ctx:
+ affine_expr = AffineExpr.get_constant(42)
+
+ affine_expr_capsule = affine_expr._CAPIPtr
+ # CHECK: capsule object
+ # CHECK: mlir.ir.AffineExpr._CAPIPtr
+ print(affine_expr_capsule)
+
+ affine_expr_2 = AffineExpr._CAPICreate(affine_expr_capsule)
+ assert affine_expr == affine_expr_2
+ assert affine_expr_2.context == ctx
+
+run(testAffineExprCapsule)
+
+
+# CHECK-LABEL: TEST: testAffineExprEq
+# Identical constants compare equal, different values do not, and comparing
+# against a non-AffineExpr object yields False instead of raising.
+def testAffineExprEq():
+ with Context():
+ a1 = AffineExpr.get_constant(42)
+ a2 = AffineExpr.get_constant(42)
+ a3 = AffineExpr.get_constant(43)
+ # CHECK: True
+ print(a1 == a1)
+ # CHECK: True
+ print(a1 == a2)
+ # CHECK: False
+ print(a1 == a3)
+ # CHECK: False
+ print(a1 == None)
+ # CHECK: False
+ print(a1 == "foo")
+
+run(testAffineExprEq)
+
+
+# CHECK-LABEL: TEST: testAffineExprContext
+# The same constant created in two different contexts must not compare equal.
+def testAffineExprContext():
+ with Context():
+ a1 = AffineExpr.get_constant(42)
+ with Context():
+ a2 = AffineExpr.get_constant(42)
+
+ # CHECK: False
+ print(a1 == a2)
+
+run(testAffineExprContext)
+
+
+# CHECK-LABEL: TEST: testAffineExprConstant
+# AffineExpr.get_constant and AffineConstantExpr.get build equal constant
+# expressions with the same value and printed form.
+def testAffineExprConstant():
+ with Context():
+ a1 = AffineExpr.get_constant(42)
+ # CHECK: 42
+ print(a1.value)
+ # CHECK: 42
+ print(a1)
+
+ a2 = AffineConstantExpr.get(42)
+ # CHECK: 42
+ print(a2.value)
+ # CHECK: 42
+ print(a2)
+
+ assert a1 == a2
+
+run(testAffineExprConstant)
+
+
+# CHECK-LABEL: TEST: testAffineExprDim
+# Dimension expressions expose their position, print as `d<position>`, and
+# compare equal only when their positions match.
+def testAffineExprDim():
+ with Context():
+ d1 = AffineExpr.get_dim(1)
+ d11 = AffineDimExpr.get(1)
+ d2 = AffineDimExpr.get(2)
+
+ # CHECK: 1
+ print(d1.position)
+ # CHECK: d1
+ print(d1)
+
+ # CHECK: 2
+ print(d2.position)
+ # CHECK: d2
+ print(d2)
+
+ assert d1 == d11
+ assert d1 != d2
+
+run(testAffineExprDim)
+
+
+# CHECK-LABEL: TEST: testAffineExprSymbol
+# Symbol expressions expose their position, print as `s<position>`, and
+# compare equal only when their positions match.
+def testAffineExprSymbol():
+  with Context():
+    s1 = AffineExpr.get_symbol(1)
+    s11 = AffineSymbolExpr.get(1)
+    s2 = AffineSymbolExpr.get(2)
+
+    # CHECK: 1
+    print(s1.position)
+    # CHECK: s1
+    print(s1)
+
+    # CHECK: 2
+    print(s2.position)
+    # CHECK: s2
+    print(s2)
+
+    assert s1 == s11
+    assert s1 != s2
+
+run(testAffineExprSymbol)
+
+
+# CHECK-LABEL: TEST: testAffineAddExpr
+# AffineExpr.get_add and the `+` operator build the same sum, whose operands
+# are exposed through `lhs` and `rhs`.
+def testAffineAddExpr():
+ with Context():
+ d1 = AffineDimExpr.get(1)
+ d2 = AffineDimExpr.get(2)
+ d12 = AffineExpr.get_add(d1, d2)
+ # CHECK: d1 + d2
+ print(d12)
+
+ d12op = d1 + d2
+ # CHECK: d1 + d2
+ print(d12op)
+
+ assert d12 == d12op
+ assert d12.lhs == d1
+ assert d12.rhs == d2
+
+run(testAffineAddExpr)
+
+
+# CHECK-LABEL: TEST: testAffineMulExpr
+# AffineExpr.get_mul and the `*` operator build the same product, whose
+# operands are exposed through `lhs` and `rhs`.
+def testAffineMulExpr():
+ with Context():
+ d1 = AffineDimExpr.get(1)
+ c2 = AffineConstantExpr.get(2)
+ expr = AffineExpr.get_mul(d1, c2)
+ # CHECK: d1 * 2
+ print(expr)
+
+ # CHECK: d1 * 2
+ op = d1 * c2
+ print(op)
+
+ assert expr == op
+ assert expr.lhs == d1
+ assert expr.rhs == c2
+
+run(testAffineMulExpr)
+
+
+# CHECK-LABEL: TEST: testAffineModExpr
+# AffineExpr.get_mod and the `%` operator build the same modulo expression,
+# whose operands are exposed through `lhs` and `rhs`.
+def testAffineModExpr():
+ with Context():
+ d1 = AffineDimExpr.get(1)
+ c2 = AffineConstantExpr.get(2)
+ expr = AffineExpr.get_mod(d1, c2)
+ # CHECK: d1 mod 2
+ print(expr)
+
+ # CHECK: d1 mod 2
+ op = d1 % c2
+ print(op)
+
+ assert expr == op
+ assert expr.lhs == d1
+ assert expr.rhs == c2
+
+run(testAffineModExpr)
+
+
+# CHECK-LABEL: TEST: testAffineFloorDivExpr
+# AffineExpr.get_floor_div builds a floordiv expression whose operands are
+# exposed through `lhs` and `rhs`.
+def testAffineFloorDivExpr():
+ with Context():
+ d1 = AffineDimExpr.get(1)
+ c2 = AffineConstantExpr.get(2)
+ expr = AffineExpr.get_floor_div(d1, c2)
+ # CHECK: d1 floordiv 2
+ print(expr)
+
+ assert expr.lhs == d1
+ assert expr.rhs == c2
+
+run(testAffineFloorDivExpr)
+
+
+# CHECK-LABEL: TEST: testAffineCeilDivExpr
+# AffineExpr.get_ceil_div builds a ceildiv expression whose operands are
+# exposed through `lhs` and `rhs`.
+def testAffineCeilDivExpr():
+ with Context():
+ d1 = AffineDimExpr.get(1)
+ c2 = AffineConstantExpr.get(2)
+ expr = AffineExpr.get_ceil_div(d1, c2)
+ # CHECK: d1 ceildiv 2
+ print(expr)
+
+ assert expr.lhs == d1
+ assert expr.rhs == c2
+
+run(testAffineCeilDivExpr)
+
+
+# CHECK-LABEL: TEST: testAffineExprSub
+# `a - b` is built as `a + b * -1`. The sum's `rhs` comes back as a plain
+# AffineExpr, so it is cast with AffineMulExpr(...) to reach its operands.
+def testAffineExprSub():
+ with Context():
+ d1 = AffineDimExpr.get(1)
+ d2 = AffineDimExpr.get(2)
+ expr = d1 - d2
+ # CHECK: d1 - d2
+ print(expr)
+
+ assert expr.lhs == d1
+ rhs = AffineMulExpr(expr.rhs)
+ # CHECK: d2
+ print(rhs.lhs)
+ # CHECK: -1
+ print(rhs.rhs)
+
+run(testAffineExprSub)
+
+
+# CHECK-LABEL: TEST: testClassHierarchy
+# Every binary expression kind is an AffineBinaryExpr while dim and constant
+# expressions are not; casting a non-binary expression to AffineBinaryExpr
+# must raise ValueError.
+def testClassHierarchy():
+  with Context():
+    d1 = AffineDimExpr.get(1)
+    c2 = AffineConstantExpr.get(2)
+    add = AffineAddExpr.get(d1, c2)
+    mul = AffineMulExpr.get(d1, c2)
+    mod = AffineModExpr.get(d1, c2)
+    floor_div = AffineFloorDivExpr.get(d1, c2)
+    ceil_div = AffineCeilDivExpr.get(d1, c2)
+
+    # CHECK: False
+    print(isinstance(d1, AffineBinaryExpr))
+    # CHECK: False
+    print(isinstance(c2, AffineBinaryExpr))
+    # CHECK: True
+    print(isinstance(add, AffineBinaryExpr))
+    # CHECK: True
+    print(isinstance(mul, AffineBinaryExpr))
+    # CHECK: True
+    print(isinstance(mod, AffineBinaryExpr))
+    # CHECK: True
+    print(isinstance(floor_div, AffineBinaryExpr))
+    # CHECK: True
+    print(isinstance(ceil_div, AffineBinaryExpr))
+
+    try:
+      AffineBinaryExpr(d1)
+    except ValueError as e:
+      # CHECK: Cannot cast affine expression to AffineBinaryExpr
+      print(e)
+    else:
+      assert False, "casting a dim expression to AffineBinaryExpr must fail"
+
+    try:
+      AffineBinaryExpr(c2)
+    except ValueError as e:
+      # CHECK: Cannot cast affine expression to AffineBinaryExpr
+      print(e)
+    else:
+      assert False, "casting a constant expression to AffineBinaryExpr must fail"
+
+run(testClassHierarchy)
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 434f272de059..550f799440f9 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -1251,6 +1251,27 @@ int printAffineExpr(MlirContext ctx) {
if (!mlirAffineExprIsACeilDiv(affineCeilDivExpr))
return 13;
+ if (!mlirAffineExprIsABinary(affineAddExpr))
+ return 14;
+
+ // Test other 'IsA' method on affine expressions.
+ if (!mlirAffineExprIsAConstant(affineConstantExpr))
+ return 15;
+
+ if (!mlirAffineExprIsADim(affineDimExpr))
+ return 16;
+
+ if (!mlirAffineExprIsASymbol(affineSymbolExpr))
+ return 17;
+
+ // Test equality and nullity.
+ MlirAffineExpr otherDimExpr = mlirAffineDimExprGet(ctx, 5);
+ if (!mlirAffineExprEqual(affineDimExpr, otherDimExpr))
+ return 18;
+
+ if (mlirAffineExprIsNull(affineDimExpr))
+ return 19;
+
return 0;
}
More information about the Mlir-commits
mailing list