[libcxx-commits] [libcxx] [MLIR][Presburger] Implement matrix inverse (PR #67382)
Arjun P via libcxx-commits
libcxx-commits at lists.llvm.org
Fri Oct 13 03:36:35 PDT 2023
================
@@ -390,4 +463,83 @@ MPInt IntMatrix::normalizeRow(unsigned row, unsigned cols) {
MPInt IntMatrix::normalizeRow(unsigned row) { // Single-argument overload: forwards to normalizeRow(row, cols) with cols = getNumColumns().
return normalizeRow(row, getNumColumns()); // The result of the two-argument overload is returned unchanged.
+}
+
+std::optional<IntMatrix> IntMatrix::integerInverse() { // Returns det * (inverse of this matrix) as an IntMatrix, or an empty optional when inverse() yields no result.
+ Fraction det = Fraction(determinant(), 1); // The determinant, wrapped as the fraction det/1.
+ FracMatrix newMat(getNumRows(), getNumColumns()); // Fraction-valued copy of this matrix.
+ for (unsigned i = 0; i < getNumRows(); i++)
+ for (unsigned j = 0; j < getNumColumns(); j++)
+ newMat(i, j) = Fraction(at(i, j), 1); // Each integer entry x is stored as x/1.
+ // Invert the fraction-valued copy.
+ std::optional<FracMatrix> fracInverse = newMat.inverse();
+ // If no inverse was produced, report that with an empty optional.
+ if (!fracInverse)
+ return {};
+ // Scale every entry of the fractional inverse by det and convert it to an integer.
+ IntMatrix intInverse(getNumRows(), getNumColumns());
+ for (unsigned i = 0; i < getNumRows(); i++)
+ for (unsigned j = 0; j < getNumColumns(); j++)
+ intInverse(i, j) = ((*fracInverse)(i, j) * det).getAsInteger(); // NOTE(review): presumably each product is integral (det * M^-1 is the adjugate) -- confirm getAsInteger()'s precondition.
+ // The result is det times the inverse, not the inverse itself.
+ return intInverse;
+}
+
+FracMatrix FracMatrix::identity(unsigned dimension) { // Forwards to Matrix::identity(dimension) and returns the result as a FracMatrix.
+ return Matrix::identity(dimension);
+}
+
+std::optional<FracMatrix> FracMatrix::inverse() {
+ // We use Gaussian elimination on the rows of the augmented matrix
+ // [M | I] to find the inverse of M. We proceed left-to-right,
+ // top-to-bottom. M is assumed to be a dim x dim matrix.
+
+ unsigned dim = getNumRows();
+
+ // Construct the augmented matrix [M | I]
+ FracMatrix augmented(dim, dim + dim);
+ for (unsigned i = 0; i < dim; i++) {
+ augmented.fillRow(i, 0);
+ for (unsigned j = 0; j < dim; j++)
+ augmented(i, j) = at(i, j);
+ augmented(i, dim + i).num = 1;
+ augmented(i, dim + i).den = 1;
+ }
+ Fraction a, b;
+ for (unsigned i = 0; i < dim; i++) {
+ if (augmented(i, i) == Fraction(0, 1))
+ for (unsigned j = i + 1; j < dim; j++)
+ if (augmented(j, i) != Fraction(0, 1)) {
+ augmented.addToRow(i, augmented.getRow(j), Fraction(1, 1));
+ break;
+ }
+
+ b = augmented(i, i);
+ if (b == 0)
+ return {};
+ for (unsigned j = 0; j < dim; j++) {
+ if (i == j || augmented(j, i) == 0)
+ continue;
+ a = augmented(j, i);
+ // Rj -> Rj - (a/b)Ri
+ augmented.addToRow(j, augmented.getRow(i), -a / b);
+ // Now (Rj)i = 0
----------------
Superty wrote:
Please write full sentences for documentation throughout and explain what each step is doing in this main loop.
https://github.com/llvm/llvm-project/pull/67382
More information about the libcxx-commits mailing list