[libcxx-commits] [libcxx] [MLIR][Presburger] Implement matrix inverse (PR #67382)

Arjun P via libcxx-commits libcxx-commits at lists.llvm.org
Fri Oct 13 03:36:35 PDT 2023


================
@@ -390,4 +463,83 @@ MPInt IntMatrix::normalizeRow(unsigned row, unsigned cols) {
 
 MPInt IntMatrix::normalizeRow(unsigned row) {
   // Cover the entire row: forward to the two-argument overload using the
   // full column count of this matrix.
   unsigned numCols = getNumColumns();
   return normalizeRow(row, numCols);
+}
+
+/// Returns the integer matrix det(M) * M^{-1} (the adjugate of M), or
+/// std::nullopt if M is singular.
+///
+/// The true inverse of an integer matrix generally has fractional entries.
+/// Scaling it by det(M) clears every denominator, because
+/// M^{-1} = adj(M) / det(M) and every entry of adj(M) is a cofactor of M,
+/// hence an integer. Callers that need the actual inverse must still divide
+/// the result by determinant().
+///
+/// NOTE(review): assumes the matrix is square, as FracMatrix::inverse() and
+/// determinant() appear to; confirm at call sites.
+std::optional<IntMatrix> IntMatrix::integerInverse() {
+  unsigned numRows = getNumRows();
+  unsigned numCols = getNumColumns();
+
+  // Lift the integer entries into a rational matrix so that Gaussian
+  // elimination can divide freely without losing precision.
+  FracMatrix fracMat(numRows, numCols);
+  for (unsigned i = 0; i < numRows; ++i)
+    for (unsigned j = 0; j < numCols; ++j)
+      fracMat(i, j) = Fraction(at(i, j), 1);
+
+  // Invert first. A singular matrix has no inverse, and bailing out here
+  // avoids paying for the determinant computation in that case.
+  std::optional<FracMatrix> fracInverse = fracMat.inverse();
+  if (!fracInverse)
+    return std::nullopt;
+
+  // Scale the rational inverse by det(M). The products are exactly the
+  // (integral) entries of the adjugate, so getAsInteger() is exact here.
+  Fraction det(determinant(), 1);
+  IntMatrix intInverse(numRows, numCols);
+  for (unsigned i = 0; i < numRows; ++i)
+    for (unsigned j = 0; j < numCols; ++j)
+      intInverse(i, j) = ((*fracInverse)(i, j) * det).getAsInteger();
+
+  return intInverse;
+}
+
+/// Returns the `dimension` x `dimension` identity matrix over the rationals.
+/// The generic base-class helper builds it; this wrapper converts the result
+/// to a FracMatrix.
+FracMatrix FracMatrix::identity(unsigned dimension) {
+  FracMatrix idMat(Matrix::identity(dimension));
+  return idMat;
+}
+
+std::optional<FracMatrix> FracMatrix::inverse() {
+  // We use Gaussian elimination on the rows of [M | I]
+  // to find the integer inverse. We proceed left-to-right,
+  // top-to-bottom. M is assumed to be a dim x dim matrix.
+
+  unsigned dim = getNumRows();
+
+  // Construct the augmented matrix [M | I]
+  FracMatrix augmented(dim, dim + dim);
+  for (unsigned i = 0; i < dim; i++) {
+    augmented.fillRow(i, 0);
+    for (unsigned j = 0; j < dim; j++)
+      augmented(i, j) = at(i, j);
+    augmented(i, dim + i).num = 1;
+    augmented(i, dim + i).den = 1;
+  }
+  Fraction a, b;
+  for (unsigned i = 0; i < dim; i++) {
+    if (augmented(i, i) == Fraction(0, 1))
+      for (unsigned j = i + 1; j < dim; j++)
+        if (augmented(j, i) != Fraction(0, 1)) {
+          augmented.addToRow(i, augmented.getRow(j), Fraction(1, 1));
+          break;
+        }
+
+    b = augmented(i, i);
+    if (b == 0)
+      return {};
+    for (unsigned j = 0; j < dim; j++) {
+      if (i == j || augmented(j, i) == 0)
+        continue;
+      a = augmented(j, i);
+      // Rj -> Rj - (b/a)Ri
+      augmented.addToRow(j, augmented.getRow(i), -a / b);
+      // Now (Rj)i = 0
----------------
Superty wrote:

Please write full sentences for documentation throughout, and explain what each step in this main loop is doing.

https://github.com/llvm/llvm-project/pull/67382


More information about the libcxx-commits mailing list