[Mlir-commits] [llvm] [mlir] [Python] Develop python bindings for Presburger library (PR #113233)

Sagar Shelke llvmlistbot at llvm.org
Mon Oct 21 16:41:02 PDT 2024


https://github.com/shelkesagar29 created https://github.com/llvm/llvm-project/pull/113233


This PR is a work in progress.
It was created so that the community is aware of the ongoing work and duplicated effort is avoided.

>From 1a4649cf1f40b638a7217602d300163822f712b8 Mon Sep 17 00:00:00 2001
From: Sagar Shelke <shelkesagar29 at yahoo.com>
Date: Thu, 17 Oct 2024 23:29:01 +0000
Subject: [PATCH] [presburger] Develop Python bindings for Presburger C++
 library

This MR is work in progress.
---
 llvm/include/llvm/ADT/DynamicAPInt.h          |   7 +
 mlir/include/mlir-c/Presburger.h              | 145 +++++++++++
 .../Analysis/Presburger/IntegerRelation.h     |   9 +
 mlir/include/mlir/CAPI/Presburger.h           |  29 +++
 mlir/lib/Bindings/Python/Presburger.cpp       | 244 ++++++++++++++++++
 mlir/lib/CAPI/CMakeLists.txt                  |   1 +
 mlir/lib/CAPI/Presburger/CMakeLists.txt       |   6 +
 mlir/lib/CAPI/Presburger/Presburger.cpp       | 163 ++++++++++++
 mlir/python/CMakeLists.txt                    |  19 ++
 mlir/python/mlir/presburger.py                |  12 +
 mlir/test/python/presburger.py                |  47 ++++
 11 files changed, 682 insertions(+)
 create mode 100644 mlir/include/mlir-c/Presburger.h
 create mode 100644 mlir/include/mlir/CAPI/Presburger.h
 create mode 100644 mlir/lib/Bindings/Python/Presburger.cpp
 create mode 100644 mlir/lib/CAPI/Presburger/CMakeLists.txt
 create mode 100644 mlir/lib/CAPI/Presburger/Presburger.cpp
 create mode 100644 mlir/python/mlir/presburger.py
 create mode 100644 mlir/test/python/presburger.py

diff --git a/llvm/include/llvm/ADT/DynamicAPInt.h b/llvm/include/llvm/ADT/DynamicAPInt.h
index ff958d48e77317..7515522edd0bf6 100644
--- a/llvm/include/llvm/ADT/DynamicAPInt.h
+++ b/llvm/include/llvm/ADT/DynamicAPInt.h
@@ -217,6 +217,13 @@ class DynamicAPInt {
 
   raw_ostream &print(raw_ostream &OS) const;
   LLVM_DUMP_METHOD void dump() const;
+
+  void *getAsOpaquePointer() const { return const_cast<DynamicAPInt *>(this); }
+
+  static DynamicAPInt *getFromOpaquePointer(const void *Pointer) {
+    return const_cast<DynamicAPInt *>(
+        reinterpret_cast<const DynamicAPInt *>(Pointer));
+  }
 };
 
 inline raw_ostream &operator<<(raw_ostream &OS, const DynamicAPInt &X) {
diff --git a/mlir/include/mlir-c/Presburger.h b/mlir/include/mlir-c/Presburger.h
new file mode 100644
index 00000000000000..dd7343c160b5c6
--- /dev/null
+++ b/mlir/include/mlir-c/Presburger.h
@@ -0,0 +1,145 @@
+#ifndef MLIR_C_PRESBURGER_H
+#define MLIR_C_PRESBURGER_H
+#include "mlir-c/AffineExpr.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/// Kinds of variables exposed by the Presburger C API. `SetDim` is an alias
+/// that shares the value of `Range`. The enum is declared through a typedef so
+/// that the bare name `MlirPresburgerVariableKind` is a valid type in C as
+/// well as in C++.
+typedef enum MlirPresburgerVariableKind {
+  Symbol,
+  Local,
+  Domain,
+  Range,
+  SetDim = Range
+} MlirPresburgerVariableKind;
+
+#define DEFINE_C_API_STRUCT(name, storage)                                     \
+  struct name {                                                                \
+    storage *ptr;                                                              \
+  };                                                                           \
+  typedef struct name name
+DEFINE_C_API_STRUCT(MlirPresburgerIntegerRelation, void);
+DEFINE_C_API_STRUCT(MlirPresburgerDynamicAPInt, const void);
+#undef DEFINE_C_API_STRUCT
+
+//===----------------------------------------------------------------------===//
+// IntegerRelation creation/destruction and basic metadata operations
+//===----------------------------------------------------------------------===//
+
+/// Constructs a relation reserving memory for the specified number
+/// of constraints and variables.
+MLIR_CAPI_EXPORTED MlirPresburgerIntegerRelation
+mlirPresburgerIntegerRelationCreate(unsigned numReservedInequalities,
+                                    unsigned numReservedEqualities,
+                                    unsigned numReservedCols);
+
+/// Constructs an IntegerRelation from packed 2D matrices of tableau
+/// coefficients in row-major order. The first `numDomainVars` columns are
+/// considered domain variables and the next `numRangeVars` columns are
+/// considered range variables. Each row holds
+/// `numDomainVars + numRangeVars + 1` entries.
+/// NOTE(review): the default arguments below are a C++ feature; a C
+/// translation unit cannot compile this declaration as written.
+MLIR_CAPI_EXPORTED MlirPresburgerIntegerRelation
+mlirPresburgerIntegerRelationCreateFromCoefficients(
+    const int64_t *inequalityCoefficients, unsigned numInequalities,
+    const int64_t *equalityCoefficients, unsigned numEqualities,
+    unsigned numDomainVars, unsigned numRangeVars,
+    unsigned numExtraReservedInequalities = 0,
+    unsigned numExtraReservedEqualities = 0, unsigned numExtraReservedCols = 0);
+
+/// Destroys an IntegerRelation.
+MLIR_CAPI_EXPORTED void
+mlirPresburgerIntegerRelationDestroy(MlirPresburgerIntegerRelation relation);
+
+//===----------------------------------------------------------------------===//
+// IntegerRelation binary operations
+//===----------------------------------------------------------------------===//
+
+/// Return whether `lhs` and `rhs` are equal. This is integer-exact
+/// and somewhat expensive, since it uses the integer emptiness check
+/// (see IntegerRelation::findIntegerSample()).
+MLIR_CAPI_EXPORTED bool
+mlirPresburgerIntegerRelationIsEqual(MlirPresburgerIntegerRelation lhs,
+                                     MlirPresburgerIntegerRelation rhs);
+
+/// Return the intersection of the two relations.
+/// If there are locals, they will be merged.
+MLIR_CAPI_EXPORTED MlirPresburgerIntegerRelation
+mlirPresburgerIntegerRelationIntersect(MlirPresburgerIntegerRelation lhs,
+                                       MlirPresburgerIntegerRelation rhs);
+
+//===----------------------------------------------------------------------===//
+// IntegerRelation Tableau Inspection
+//===----------------------------------------------------------------------===//
+
+/// Returns the value at the specified equality row and column of `relation`.
+MLIR_CAPI_EXPORTED MlirPresburgerDynamicAPInt
+mlirPresburgerIntegerRelationAtEq(MlirPresburgerIntegerRelation relation,
+                                  unsigned row, unsigned col);
+
+/// The same, but casts to int64_t. This is unsafe and will assert-fail if the
+/// value does not fit in an int64_t.
+MLIR_CAPI_EXPORTED int64_t mlirPresburgerIntegerRelationAtEq64(
+    MlirPresburgerIntegerRelation relation, unsigned row, unsigned col);
+
+/// Returns the value at the specified inequality row and column.
+MLIR_CAPI_EXPORTED MlirPresburgerDynamicAPInt
+mlirPresburgerIntegerRelationAtIneq(MlirPresburgerIntegerRelation relation,
+                                    unsigned row, unsigned col);
+
+MLIR_CAPI_EXPORTED int64_t mlirPresburgerIntegerRelationAtIneq64(
+    MlirPresburgerIntegerRelation relation, unsigned row, unsigned col);
+
+/// Returns the number of inequalities and equalities.
+MLIR_CAPI_EXPORTED unsigned mlirPresburgerIntegerRelationNumConstraints(
+    MlirPresburgerIntegerRelation relation);
+
+/// Returns the number of inequality constraints.
+MLIR_CAPI_EXPORTED unsigned mlirPresburgerIntegerRelationNumInequalities(
+    MlirPresburgerIntegerRelation relation);
+
+/// Returns the number of equality constraints.
+MLIR_CAPI_EXPORTED unsigned mlirPresburgerIntegerRelationNumEqualities(
+    MlirPresburgerIntegerRelation relation);
+
+/// Returns the number of columns classified as domain variables.
+MLIR_CAPI_EXPORTED unsigned mlirPresburgerIntegerRelationNumDomainVars(
+    MlirPresburgerIntegerRelation relation);
+
+/// Returns the number of columns classified as range variables.
+MLIR_CAPI_EXPORTED unsigned mlirPresburgerIntegerRelationNumRangeVars(
+    MlirPresburgerIntegerRelation relation);
+
+/// Returns the number of columns classified as symbol variables.
+MLIR_CAPI_EXPORTED unsigned mlirPresburgerIntegerRelationNumSymbolVars(
+    MlirPresburgerIntegerRelation relation);
+
+/// Returns the number of columns classified as local variables.
+MLIR_CAPI_EXPORTED unsigned mlirPresburgerIntegerRelationNumLocalVars(
+    MlirPresburgerIntegerRelation relation);
+
+/// Returns the total number of columns in the tableau.
+MLIR_CAPI_EXPORTED unsigned
+mlirPresburgerIntegerRelationNumCols(MlirPresburgerIntegerRelation relation);
+
+/// Return the VarKind of the var of `relation` at the specified position.
+MLIR_CAPI_EXPORTED MlirPresburgerVariableKind
+mlirPresburgerIntegerRelationGetVarKindAt(
+    MlirPresburgerIntegerRelation relation, unsigned pos);
+
+MLIR_CAPI_EXPORTED void
+mlirPresburgerIntegerRelationDump(MlirPresburgerIntegerRelation relation);
+
+//===----------------------------------------------------------------------===//
+// IntegerRelation Tableau Manipulation
+//===----------------------------------------------------------------------===//
+/// Adds an equality with the given coefficients to `relation`.
+/// `coefficients` points to `coefficientsSize` values.
+MLIR_CAPI_EXPORTED void mlirPresburgerIntegerRelationAddEquality(
+    MlirPresburgerIntegerRelation relation, const int64_t *coefficients,
+    size_t coefficientsSize);
+
+/// Adds an inequality with the given coefficients to `relation`.
+/// `coefficients` points to `coefficientsSize` values.
+MLIR_CAPI_EXPORTED void mlirPresburgerIntegerRelationAddInequality(
+    MlirPresburgerIntegerRelation relation, const int64_t *coefficients,
+    size_t coefficientsSize);
+#ifdef __cplusplus
+}
+#endif
+#endif // MLIR_C_PRESBURGER_H
\ No newline at end of file
diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index a27fc8c37eeda1..b58e2be164ec8f 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -753,6 +753,15 @@ class IntegerRelation {
   // false.
   bool isFullDim();
 
+  void *getAsOpaquePointer() const {
+    return const_cast<IntegerRelation *>(this);
+  }
+
+  static IntegerRelation *getFromOpaquePointer(const void *pointer) {
+    return const_cast<IntegerRelation *>(
+        reinterpret_cast<const IntegerRelation *>(pointer));
+  }
+
   void print(raw_ostream &os) const;
   void dump() const;
 
diff --git a/mlir/include/mlir/CAPI/Presburger.h b/mlir/include/mlir/CAPI/Presburger.h
new file mode 100644
index 00000000000000..cc88d13b638928
--- /dev/null
+++ b/mlir/include/mlir/CAPI/Presburger.h
@@ -0,0 +1,29 @@
+#ifndef MLIR_CAPI_PRESBURGER_H
+#define MLIR_CAPI_PRESBURGER_H
+
+#include "mlir-c/Presburger.h"
+#include "mlir/Analysis/Presburger/IntegerRelation.h"
+#include "mlir/Analysis/Presburger/PresburgerSpace.h"
+#include "mlir/CAPI/Wrap.h"
+#include "llvm/ADT/DynamicAPInt.h"
+
+DEFINE_C_API_PTR_METHODS(MlirPresburgerIntegerRelation,
+                         mlir::presburger::IntegerRelation)
+
+static inline MlirPresburgerDynamicAPInt wrap(llvm::DynamicAPInt *cpp) {
+  return MlirPresburgerDynamicAPInt{cpp->getAsOpaquePointer()};
+}
+
+static inline llvm::DynamicAPInt *unwrap(MlirPresburgerDynamicAPInt c) {
+  return llvm::DynamicAPInt::getFromOpaquePointer(c.ptr);
+}
+
+static inline MlirPresburgerVariableKind wrap(mlir::presburger::VarKind var) {
+  return static_cast<MlirPresburgerVariableKind>(var);
+}
+
+/// Converts a C API variable kind back to the C++ `VarKind`; the inverse of
+/// the `wrap(mlir::presburger::VarKind)` overload.
+static inline mlir::presburger::VarKind unwrap(MlirPresburgerVariableKind var) {
+  return static_cast<mlir::presburger::VarKind>(var);
+}
+
+#endif /* MLIR_CAPI_PRESBURGER_H */
\ No newline at end of file
diff --git a/mlir/lib/Bindings/Python/Presburger.cpp b/mlir/lib/Bindings/Python/Presburger.cpp
new file mode 100644
index 00000000000000..b652f75439efa5
--- /dev/null
+++ b/mlir/lib/Bindings/Python/Presburger.cpp
@@ -0,0 +1,244 @@
+#include "mlir-c/Presburger.h"
+#include "PybindUtils.h"
+#include "mlir-c/Bindings/Python/Interop.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopeExit.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatVariadic.h"
+#include <pybind11/attr.h>
+#include <pybind11/pybind11.h>
+#include <stdexcept>
+
+namespace py = pybind11;
+
+static bool isSignedIntegerFormat(std::string_view format) {
+  if (format.empty())
+    return false;
+  char code = format[0];
+  return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
+         code == 'q';
+}
+
+namespace {
+struct PyPresburgerIntegerRelation {
+  PyPresburgerIntegerRelation(MlirPresburgerIntegerRelation relation)
+      : relation(relation) {}
+  PyPresburgerIntegerRelation(PyPresburgerIntegerRelation &&other) noexcept
+      : relation(other.relation) {
+    other.relation.ptr = nullptr;
+  }
+  ~PyPresburgerIntegerRelation() {
+    if (relation.ptr) {
+      mlirPresburgerIntegerRelationDestroy(relation);
+      relation.ptr = {nullptr};
+    }
+  }
+  static PyPresburgerIntegerRelation
+  getFromBuffers(py::buffer inequalitiesCoefficients,
+                 py::buffer equalityCoefficients, unsigned numDomainVars,
+                 unsigned numRangeVars);
+  py::object getCapsule();
+  int64_t getNumConstraints() {
+    return mlirPresburgerIntegerRelationNumConstraints(relation);
+  }
+  int64_t getNumInequalities() {
+    return mlirPresburgerIntegerRelationNumInequalities(relation);
+  }
+  int64_t getNumEqualities() {
+    return mlirPresburgerIntegerRelationNumEqualities(relation);
+  }
+  int64_t getNumDomainVars() {
+    return mlirPresburgerIntegerRelationNumDomainVars(relation);
+  }
+  int64_t getNumRangeVars() {
+    return mlirPresburgerIntegerRelationNumRangeVars(relation);
+  }
+  int64_t getNumSymbolVars() {
+    return mlirPresburgerIntegerRelationNumSymbolVars(relation);
+  }
+  int64_t getNumLocalVars() {
+    return mlirPresburgerIntegerRelationNumLocalVars(relation);
+  }
+  int64_t getNumCols() {
+    return mlirPresburgerIntegerRelationNumCols(relation);
+  }
+  MlirPresburgerIntegerRelation relation{nullptr};
+};
+
+/// A utility that enables accessing/modifying the underlying coefficients
+/// easier.
+struct PyPresburgerTableau {
+  enum class Kind { Equalities, Inequalities };
+  PyPresburgerTableau(MlirPresburgerIntegerRelation relation, Kind kind)
+      : relation(relation), kind(kind) {}
+  static void bind(py::module &module);
+  int64_t at64(int64_t row, int64_t col) const {
+    if (kind == Kind::Equalities)
+      return mlirPresburgerIntegerRelationAtEq64(relation, row, col);
+    return mlirPresburgerIntegerRelationAtIneq64(relation, row, col);
+  }
+  MlirPresburgerIntegerRelation relation;
+  Kind kind;
+};
+} // namespace
+
+/// Builds an IntegerRelation from two Python buffers holding the inequality
+/// and equality coefficient matrices. Each buffer must be contiguous, hold
+/// signed 8-byte integers, and be two-dimensional; both must have
+/// `numDomainVars + numRangeVars + 1` columns. Violations are reported by
+/// throwing `std::invalid_argument`; a failure to obtain a buffer view
+/// propagates the pending Python error. Both buffer views are released when
+/// the function exits, whether normally or by exception.
+PyPresburgerIntegerRelation PyPresburgerIntegerRelation::getFromBuffers(
+    py::buffer inequalitiesCoefficients, py::buffer equalityCoefficients,
+    unsigned numDomainVars, unsigned numRangeVars) {
+  // Request a contiguous view. In exotic cases, this will cause a copy.
+  int flags = PyBUF_ND;
+  flags |= PyBUF_FORMAT;
+  // Get the view of the inequality coefficients.
+  std::unique_ptr<Py_buffer> ineqView = std::make_unique<Py_buffer>();
+  if (PyObject_GetBuffer(inequalitiesCoefficients.ptr(), ineqView.get(),
+                         flags) != 0)
+    throw py::error_already_set();
+  // Release the inequality view on every exit path from here on.
+  auto freeIneqBuffer = llvm::make_scope_exit([&]() {
+    if (ineqView)
+      PyBuffer_Release(ineqView.get());
+  });
+  if (!PyBuffer_IsContiguous(ineqView.get(), 'A'))
+    throw std::invalid_argument("Contiguous buffer is required.");
+  if (!isSignedIntegerFormat(ineqView->format) || ineqView->itemsize != 8)
+    throw std::invalid_argument(
+        std::string("IntegerRelation can only be created from a buffer of "
+                    "i64 values but got buffer with format: ") +
+        std::string(ineqView->format));
+  if (ineqView->ndim != 2)
+    throw std::invalid_argument(
+        std::string("expected 2d inequality coefficients but got rank ") +
+        std::to_string(ineqView->ndim));
+  unsigned numInequalities = ineqView->shape[0];
+  // Get the view of the equality coefficients.
+  std::unique_ptr<Py_buffer> eqView = std::make_unique<Py_buffer>();
+  if (PyObject_GetBuffer(equalityCoefficients.ptr(), eqView.get(), flags) != 0)
+    throw py::error_already_set();
+  // Release the equality view on every exit path from here on.
+  auto freeEqBuffer = llvm::make_scope_exit([&]() {
+    if (eqView)
+      PyBuffer_Release(eqView.get());
+  });
+  if (!PyBuffer_IsContiguous(eqView.get(), 'A'))
+    throw std::invalid_argument("Contiguous buffer is required.");
+  if (!isSignedIntegerFormat(eqView->format) || eqView->itemsize != 8)
+    throw std::invalid_argument(
+        std::string("IntegerRelation can only be created from a buffer of "
+                    "i64 values but got buffer with format: ") +
+        std::string(eqView->format));
+  if (eqView->ndim != 2)
+    throw std::invalid_argument(
+        std::string("expected 2d equality coefficients but got rank ") +
+        std::to_string(eqView->ndim));
+  unsigned numEqualities = eqView->shape[0];
+  // Both matrices must share the same column count, and that count must
+  // equal numDomainVars + numRangeVars + 1.
+  if (eqView->shape[1] != numDomainVars + numRangeVars + 1 ||
+      eqView->shape[1] != ineqView->shape[1])
+    throw std::invalid_argument(
+        "expected number of columns of inequality and equality coefficient "
+        "matrices to equal numRangeVars + numDomainVars + 1");
+  MlirPresburgerIntegerRelation relation =
+      mlirPresburgerIntegerRelationCreateFromCoefficients(
+          reinterpret_cast<const int64_t *>(ineqView->buf), numInequalities,
+          reinterpret_cast<const int64_t *>(eqView->buf), numEqualities,
+          numDomainVars, numRangeVars);
+  return PyPresburgerIntegerRelation(relation);
+}
+
+py::object PyPresburgerIntegerRelation::getCapsule() {
+  throw std::invalid_argument("unimplemented");
+}
+
+void PyPresburgerTableau::bind(py::module &m) {
+  py::class_<PyPresburgerTableau>(m, "IntegerRelationTableau",
+                                  py::module_local())
+      .def("__getitem__", [](PyPresburgerTableau &tableau,
+                             const py::tuple &index) {
+        return tableau.at64(index[0].cast<int64_t>(), index[1].cast<int64_t>());
+      });
+}
+
+static void populatePresburgerModule(py::module &m) {
+  PyPresburgerTableau::bind(m);
+  py::class_<PyPresburgerIntegerRelation>(m, "IntegerRelation",
+                                          py::module_local())
+      .def(py::init<>(&PyPresburgerIntegerRelation::getFromBuffers))
+      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
+                             &PyPresburgerIntegerRelation::getCapsule)
+      .def("__eq__",
+           [](PyPresburgerIntegerRelation &relation,
+              PyPresburgerIntegerRelation &other) {
+             return mlirPresburgerIntegerRelationIsEqual(relation.relation,
+                                                         other.relation);
+           })
+      .def(
+          "intersect",
+          [](PyPresburgerIntegerRelation &relation,
+             PyPresburgerIntegerRelation &other) {
+            PyPresburgerIntegerRelation intersection(
+                mlirPresburgerIntegerRelationIntersect(relation.relation,
+                                                       other.relation));
+            return intersection;
+          },
+          py::keep_alive<0, 1>())
+      .def(
+          "inequalities",
+          [](PyPresburgerIntegerRelation &relation) {
+            PyPresburgerTableau tableau(
+                relation.relation, PyPresburgerTableau::Kind::Inequalities);
+            return tableau;
+          },
+          py::keep_alive<0, 1>())
+      .def(
+          "equalities",
+          [](PyPresburgerIntegerRelation &relation) {
+            PyPresburgerTableau tableau(relation.relation,
+                                        PyPresburgerTableau::Kind::Equalities);
+            return tableau;
+          },
+          py::keep_alive<0, 1>())
+      .def("get_equality",
+           [](PyPresburgerIntegerRelation &relation, int64_t row) {
+             unsigned numCol =
+                 mlirPresburgerIntegerRelationNumCols(relation.relation);
+             std::vector<int64_t> result(numCol);
+             for (unsigned i = 0; i < numCol; i++)
+               result[i] = mlirPresburgerIntegerRelationAtEq64(
+                   relation.relation, row, i);
+             return result;
+           })
+      .def("get_inequality",
+           [](PyPresburgerIntegerRelation &relation, int64_t row) {
+             unsigned numCol =
+                 mlirPresburgerIntegerRelationNumCols(relation.relation);
+             std::vector<int64_t> result(numCol);
+             for (unsigned i = 0; i < numCol; i++)
+               result[i] = mlirPresburgerIntegerRelationAtIneq64(
+                   relation.relation, row, i);
+             return result;
+           })
+      .def_property_readonly("num_constraints",
+                             &PyPresburgerIntegerRelation::getNumConstraints)
+      .def_property_readonly("num_equalities",
+                             &PyPresburgerIntegerRelation::getNumEqualities)
+      .def_property_readonly("num_inequalities",
+                             &PyPresburgerIntegerRelation::getNumInequalities)
+      .def_property_readonly("num_domain_vars",
+                             &PyPresburgerIntegerRelation::getNumDomainVars)
+      .def_property_readonly("num_range_vars",
+                             &PyPresburgerIntegerRelation::getNumRangeVars)
+      .def_property_readonly("num_symbol_vars",
+                             &PyPresburgerIntegerRelation::getNumSymbolVars)
+      .def_property_readonly("num_local_vars",
+                             &PyPresburgerIntegerRelation::getNumLocalVars)
+      .def_property_readonly("num_columns",
+                             &PyPresburgerIntegerRelation::getNumCols)
+      .def("__str__", [](PyPresburgerIntegerRelation &relation) {
+        mlirPresburgerIntegerRelationDump(relation.relation);
+        return "";
+      });
+}
+// -----------------------------------------------------------------------------
+// Module initialization.
+// -----------------------------------------------------------------------------
+PYBIND11_MODULE(_mlirPresburger, m) {
+  m.doc() = "MLIR Presburger utilities";
+  populatePresburgerModule(m);
+}
\ No newline at end of file
diff --git a/mlir/lib/CAPI/CMakeLists.txt b/mlir/lib/CAPI/CMakeLists.txt
index 6c438508425b7c..56888798e92292 100644
--- a/mlir/lib/CAPI/CMakeLists.txt
+++ b/mlir/lib/CAPI/CMakeLists.txt
@@ -15,6 +15,7 @@ add_subdirectory(IR)
 add_subdirectory(RegisterEverything)
 add_subdirectory(Transforms)
 add_subdirectory(Target)
+add_subdirectory(Presburger)
 
 if(MLIR_ENABLE_EXECUTION_ENGINE)
   add_subdirectory(ExecutionEngine)
diff --git a/mlir/lib/CAPI/Presburger/CMakeLists.txt b/mlir/lib/CAPI/Presburger/CMakeLists.txt
new file mode 100644
index 00000000000000..956006233dda5b
--- /dev/null
+++ b/mlir/lib/CAPI/Presburger/CMakeLists.txt
@@ -0,0 +1,6 @@
+add_mlir_upstream_c_api_library(MLIRCAPIPresburger
+  Presburger.cpp
+  
+  LINK_LIBS PUBLIC
+  MLIRPresburger
+  )
\ No newline at end of file
diff --git a/mlir/lib/CAPI/Presburger/Presburger.cpp b/mlir/lib/CAPI/Presburger/Presburger.cpp
new file mode 100644
index 00000000000000..e07b56ea291312
--- /dev/null
+++ b/mlir/lib/CAPI/Presburger/Presburger.cpp
@@ -0,0 +1,163 @@
+#include "mlir/CAPI/Presburger.h"
+#include "mlir-c/Presburger.h"
+#include "mlir/Analysis/Presburger/IntegerRelation.h"
+#include "mlir/IR/Region.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/Support/Debug.h"
+using namespace mlir;
+using namespace mlir::presburger;
+
+//===----------------------------------------------------------------------===//
+// IntegerRelation creation/destruction and basic metadata operations
+//===----------------------------------------------------------------------===//
+
+MlirPresburgerIntegerRelation
+mlirPresburgerIntegerRelationCreate(unsigned numReservedInequalities,
+                                    unsigned numReservedEqualities,
+                                    unsigned numReservedCols) {
+  auto space = PresburgerSpace::getRelationSpace();
+  IntegerRelation *relation = new IntegerRelation(
+      numReservedInequalities, numReservedEqualities, numReservedCols, space);
+  return wrap(relation);
+}
+
+/// Builds an IntegerRelation over `numDomainVars` domain and `numRangeVars`
+/// range variables from packed row-major coefficient matrices. Every row holds
+/// `numDomainVars + numRangeVars + 1` entries. Storage is reserved for the
+/// given constraints plus the requested extra inequalities, equalities and
+/// columns. The relation is heap-allocated; release it with
+/// mlirPresburgerIntegerRelationDestroy.
+MlirPresburgerIntegerRelation
+mlirPresburgerIntegerRelationCreateFromCoefficients(
+    const int64_t *inequalityCoefficients, unsigned numInequalities,
+    const int64_t *equalityCoefficients, unsigned numEqualities,
+    unsigned numDomainVars, unsigned numRangeVars,
+    unsigned numExtraReservedInequalities, unsigned numExtraReservedEqualities,
+    unsigned numExtraReservedCols) {
+  const unsigned numCols = numDomainVars + numRangeVars + 1;
+  auto space = PresburgerSpace::getRelationSpace(numDomainVars, numRangeVars);
+  // Each dimension is reserved with its own "extra" amount: equalities use
+  // numExtraReservedEqualities and columns use numExtraReservedCols.
+  IntegerRelation *relation =
+      new IntegerRelation(numInequalities + numExtraReservedInequalities,
+                          numEqualities + numExtraReservedEqualities,
+                          numCols + numExtraReservedCols, space);
+  // Walk the packed matrices one row (numCols entries) at a time. The end
+  // offsets are computed in size_t so large matrices cannot wrap `unsigned`.
+  const int64_t *ineqEnd =
+      inequalityCoefficients + static_cast<size_t>(numCols) * numInequalities;
+  for (const int64_t *rowPtr = inequalityCoefficients; rowPtr < ineqEnd;
+       rowPtr += numCols) {
+    llvm::ArrayRef<int64_t> coef(rowPtr, rowPtr + numCols);
+    relation->addInequality(coef);
+  }
+  const int64_t *eqEnd =
+      equalityCoefficients + static_cast<size_t>(numCols) * numEqualities;
+  for (const int64_t *rowPtr = equalityCoefficients; rowPtr < eqEnd;
+       rowPtr += numCols) {
+    llvm::ArrayRef<int64_t> coef(rowPtr, rowPtr + numCols);
+    relation->addEquality(coef);
+  }
+  return wrap(relation);
+}
+
+void mlirPresburgerIntegerRelationDestroy(
+    MlirPresburgerIntegerRelation relation) {
+  if (relation.ptr)
+    delete reinterpret_cast<IntegerRelation *>(relation.ptr);
+}
+
+unsigned mlirPresburgerIntegerRelationNumConstraints(
+    MlirPresburgerIntegerRelation relation) {
+  return unwrap(relation)->getNumConstraints();
+}
+
+unsigned mlirPresburgerIntegerRelationNumInequalities(
+    MlirPresburgerIntegerRelation relation) {
+  return unwrap(relation)->getNumInequalities();
+}
+
+unsigned mlirPresburgerIntegerRelationNumEqualities(
+    MlirPresburgerIntegerRelation relation) {
+  return unwrap(relation)->getNumEqualities();
+}
+
+unsigned mlirPresburgerIntegerRelationNumDomainVars(
+    MlirPresburgerIntegerRelation relation) {
+  return unwrap(relation)->getNumDomainVars();
+}
+
+unsigned mlirPresburgerIntegerRelationNumRangeVars(
+    MlirPresburgerIntegerRelation relation) {
+  return unwrap(relation)->getNumRangeVars();
+}
+
+unsigned mlirPresburgerIntegerRelationNumSymbolVars(
+    MlirPresburgerIntegerRelation relation) {
+  return unwrap(relation)->getNumSymbolVars();
+}
+
+unsigned mlirPresburgerIntegerRelationNumLocalVars(
+    MlirPresburgerIntegerRelation relation) {
+  return unwrap(relation)->getNumLocalVars();
+}
+
+unsigned
+mlirPresburgerIntegerRelationNumCols(MlirPresburgerIntegerRelation relation) {
+  return unwrap(relation)->getNumCols();
+}
+
+MlirPresburgerVariableKind mlirPresburgerIntegerRelationGetVarKindAt(
+    MlirPresburgerIntegerRelation relation, unsigned pos) {
+  return wrap(unwrap(relation)->getVarKindAt(pos));
+}
+
+void mlirPresburgerIntegerRelationDump(MlirPresburgerIntegerRelation relation) {
+  unwrap(relation)->dump();
+}
+
+//===----------------------------------------------------------------------===//
+// IntegerRelation binary operations
+//===----------------------------------------------------------------------===//
+
+bool mlirPresburgerIntegerRelationIsEqual(MlirPresburgerIntegerRelation lhs,
+                                          MlirPresburgerIntegerRelation rhs) {
+  return unwrap(lhs)->isEqual(*(unwrap(rhs)));
+}
+
+/// Returns a newly allocated relation holding the intersection of `lhs` and
+/// `rhs`. The result is heap-allocated so the returned handle stays valid
+/// after this function returns; release it with
+/// mlirPresburgerIntegerRelationDestroy.
+MlirPresburgerIntegerRelation
+mlirPresburgerIntegerRelationIntersect(MlirPresburgerIntegerRelation lhs,
+                                       MlirPresburgerIntegerRelation rhs) {
+  return wrap(new IntegerRelation(unwrap(lhs)->intersect(*unwrap(rhs))));
+}
+
+//===----------------------------------------------------------------------===//
+// IntegerRelation Tableau Inspection and Manipulation
+//===----------------------------------------------------------------------===//
+
+MlirPresburgerDynamicAPInt
+mlirPresburgerIntegerRelationAtEq(MlirPresburgerIntegerRelation relation,
+                                  unsigned i, unsigned j) {
+  return wrap(&unwrap(relation)->atEq(i, j));
+}
+
+int64_t
+mlirPresburgerIntegerRelationAtEq64(MlirPresburgerIntegerRelation relation,
+                                    unsigned row, unsigned col) {
+  return unwrap(relation)->atEq64(row, col);
+}
+
+MlirPresburgerDynamicAPInt
+mlirPresburgerIntegerRelationAtIneq(MlirPresburgerIntegerRelation relation,
+                                    unsigned row, unsigned col) {
+  return wrap(&unwrap(relation)->atIneq(row, col));
+}
+
+int64_t
+mlirPresburgerIntegerRelationAtIneq64(MlirPresburgerIntegerRelation relation,
+                                      unsigned row, unsigned col) {
+  return unwrap(relation)->atIneq64(row, col);
+}
+
+void mlirPresburgerIntegerRelationAddEquality(
+    MlirPresburgerIntegerRelation relation, const int64_t *coefficients,
+    size_t coefficientsSize) {
+  unwrap(relation)->addEquality(
+      llvm::ArrayRef<int64_t>(coefficients, coefficients + coefficientsSize));
+}
+
+/// Appends an inequality to `relation`. `coefficients` points to
+/// `coefficientsSize` values that form the new row.
+void mlirPresburgerIntegerRelationAddInequality(
+    MlirPresburgerIntegerRelation relation, const int64_t *coefficients,
+    size_t coefficientsSize) {
+  unwrap(relation)->addInequality(
+      llvm::ArrayRef<int64_t>(coefficients, coefficients + coefficientsSize));
+}
\ No newline at end of file
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 23187f256455bb..f47cbddc7e371f 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -48,6 +48,13 @@ declare_mlir_python_sources(MLIRPythonSources.ExecutionEngine
     runtime/*.py
 )
 
+declare_mlir_python_sources(MLIRPythonSources.Presburger
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  ADD_TO_PARENT MLIRPythonSources
+  SOURCES
+    presburger.py
+  )
+
 declare_mlir_python_sources(MLIRPythonCAPI.HeaderSources
   ROOT_DIR "${MLIR_SOURCE_DIR}/include"
   SOURCES_GLOB "mlir-c/*.h"
@@ -666,6 +673,18 @@ declare_mlir_python_extension(MLIRPythonExtension.TransformInterpreter
     MLIRCAPITransformDialectTransforms
 )
 
+declare_mlir_python_extension(MLIRPythonExtension.Presburger
+  MODULE_NAME _mlirPresburger
+  ADD_TO_PARENT MLIRPythonSources.Presburger
+  ROOT_DIR "${PYTHON_SOURCE_DIR}"
+  SOURCES
+    Presburger.cpp
+  PRIVATE_LINK_LIBS
+    LLVMSupport
+  EMBED_CAPI_LINK_LIBS
+    MLIRCAPIPresburger
+)
+
 # TODO: Figure out how to put this in the test tree.
 # This should not be included in the main Python extension. However,
 # putting it into MLIRPythonTestSources along with the dialect declaration
diff --git a/mlir/python/mlir/presburger.py b/mlir/python/mlir/presburger.py
new file mode 100644
index 00000000000000..85e57eb0d2ee86
--- /dev/null
+++ b/mlir/python/mlir/presburger.py
@@ -0,0 +1,12 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Simply a wrapper around the extension module of the same name.
+from ._mlir_libs import _mlirPresburger as _presburger
+
+__all__ = ["IntegerRelation"]
+
+
+class IntegerRelation(_presburger.IntegerRelation):
+    pass
\ No newline at end of file
diff --git a/mlir/test/python/presburger.py b/mlir/test/python/presburger.py
new file mode 100644
index 00000000000000..cf5dcff0ebdddd
--- /dev/null
+++ b/mlir/test/python/presburger.py
@@ -0,0 +1,47 @@
+from mlir import presburger
+import numpy as np
+
+"""
+Test the following integer relation
+
+x + 2y = 8
+x - y <= 1
+y >= 3
+"""
+eqs = np.asarray([[1, 2, -8]], dtype=np.int64)
+ineqs = np.asarray([[1, -1, -1], [0, -1, 1]], dtype=np.int64)
+relation = presburger.IntegerRelation(ineqs, eqs, 2, 0)
+print(relation)
+print(relation.num_constraints)
+print(relation.num_inequalities)
+print(relation.num_equalities)
+print(relation.num_domain_vars)
+print(relation.num_range_vars)
+print(relation.num_symbol_vars)
+print(relation.num_local_vars)
+print(relation.num_columns)
+
+eq_first_row = relation.get_equality(0)
+print(eq_first_row)
+ineq_second_row = relation.get_inequality(1)
+print(ineq_second_row)
+
+eq_coefficients = relation.equalities()
+print(eq_coefficients[0, 1])
+ineq_coefficients = relation.inequalities()
+print(ineq_coefficients[1, 1])
+
+"""
+Test intersection
+
+Relation A
+
+x + y <= 10
+x >= 0
+y >= 0
+
+Relation B
+
+2x + y <= 12
+y >= 2
+"""
\ No newline at end of file



More information about the Mlir-commits mailing list