[Mlir-commits] [mlir] f08fe1f - [MLIR][Presburger] Implement matrix inverse (#67382)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 20 09:03:49 PDT 2023


Author: Abhinav271828
Date: 2023-10-20T17:03:45+01:00
New Revision: f08fe1f1dd64d754064d1094704ee2938c25c325

URL: https://github.com/llvm/llvm-project/commit/f08fe1f1dd64d754064d1094704ee2938c25c325
DIFF: https://github.com/llvm/llvm-project/commit/f08fe1f1dd64d754064d1094704ee2938c25c325.diff

LOG: [MLIR][Presburger] Implement matrix inverse (#67382)

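Move the `determinant()` function from LinearTransform to the matrix classes (IntMatrix and the new FracMatrix).
Implement a FracMatrix class, inheriting from Matrix<Fraction>, to hold rational matrices such as inverses.
Let `determinant()` optionally compute the inverse (FracMatrix) or the integer inverse det(M) * M^-1 (IntMatrix) through an output pointer.
Make Matrix internals protected instead of private so that IntMatrix and FracMatrix can access them.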

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/Presburger/LinearTransform.h
    mlir/include/mlir/Analysis/Presburger/Matrix.h
    mlir/lib/Analysis/Presburger/Matrix.cpp
    mlir/unittests/Analysis/Presburger/MatrixTest.cpp
    mlir/unittests/Analysis/Presburger/Utils.h

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/Presburger/LinearTransform.h b/mlir/include/mlir/Analysis/Presburger/LinearTransform.h
index 67dc9e87facb9ad..b5c761439f0b7e6 100644
--- a/mlir/include/mlir/Analysis/Presburger/LinearTransform.h
+++ b/mlir/include/mlir/Analysis/Presburger/LinearTransform.h
@@ -50,10 +50,6 @@ class LinearTransform {
     return matrix.postMultiplyWithColumn(colVec);
   }
 
-  // Compute the determinant of the transform by converting it to row echelon
-  // form and then taking the product of the diagonal.
-  MPInt determinant();
-
 private:
   IntMatrix matrix;
 };

diff  --git a/mlir/include/mlir/Analysis/Presburger/Matrix.h b/mlir/include/mlir/Analysis/Presburger/Matrix.h
index 29f8b7d2b304e93..4d9f13832e0692a 100644
--- a/mlir/include/mlir/Analysis/Presburger/Matrix.h
+++ b/mlir/include/mlir/Analysis/Presburger/Matrix.h
@@ -189,7 +189,7 @@ class Matrix {
   /// invariants satisfied.
   bool hasConsistentState() const;
 
-private:
+protected:
   /// The current number of rows, columns, and reserved columns. The underlying
   /// data vector is viewed as an nRows x nReservedColumns matrix, of which the
   /// first nColumns columns are currently in use, and the remaining are
@@ -210,13 +210,7 @@ class IntMatrix : public Matrix<MPInt> {
             unsigned reservedColumns = 0)
       : Matrix<MPInt>(rows, columns, reservedRows, reservedColumns){};
 
-  IntMatrix(Matrix<MPInt> m)
-      : Matrix<MPInt>(m.getNumRows(), m.getNumColumns(), m.getNumReservedRows(),
-                      m.getNumReservedColumns()) {
-    for (unsigned i = 0; i < m.getNumRows(); i++)
-      for (unsigned j = 0; j < m.getNumColumns(); j++)
-        at(i, j) = m(i, j);
-  };
+  /// Construct from a generic integer matrix by moving the (by-value)
+  /// argument into the base class.
+  IntMatrix(Matrix<MPInt> m) : Matrix<MPInt>(std::move(m)) {}
 
   /// Return the identity matrix of the specified dimension.
   static IntMatrix identity(unsigned dimension);
@@ -239,6 +233,38 @@ class IntMatrix : public Matrix<MPInt> {
   /// Divide the columns of the specified row by their GCD.
   /// Returns the GCD of the columns of the specified row.
   MPInt normalizeRow(unsigned row);
+
+  /// Compute the determinant of the matrix (cubic time).
+  /// If `inverse` is non-null, stores the integer inverse of the matrix
+  /// there. `*inverse` is left unchanged if the inverse does not exist,
+  /// which happens iff det = 0.
+  /// For a matrix M, the integer inverse is the matrix M' such that
+  /// M x M' = M' x M = det(M) x I (i.e. M' is the adjugate of M).
+  /// Assert-fails if the matrix is not square.
+  MPInt determinant(IntMatrix *inverse = nullptr) const;
+};
+
+/// A matrix of rationals. It adds no data members to Matrix<Fraction>; it
+/// only hosts functionality that applies solely to matrices of fractions,
+/// such as computing an inverse.
+class FracMatrix : public Matrix<Fraction> {
+public:
+  FracMatrix(unsigned rows, unsigned columns, unsigned reservedRows = 0,
+             unsigned reservedColumns = 0)
+      : Matrix<Fraction>(rows, columns, reservedRows, reservedColumns) {}
+
+  /// Wrap a generic rational matrix, moving the (by-value) argument into the
+  /// base class. Left implicit so Matrix<Fraction> values convert directly.
+  FracMatrix(Matrix<Fraction> m) : Matrix<Fraction>(std::move(m)) {}
+
+  /// Convert an integer matrix into a rational one, entry by entry.
+  explicit FracMatrix(IntMatrix m);
+
+  /// Return the identity matrix of the specified dimension.
+  static FracMatrix identity(unsigned dimension);
+
+  /// Compute the determinant of the matrix (cubic time).
+  /// If `inverse` is non-null, stores the inverse of the matrix there.
+  /// `*inverse` is left unchanged if the inverse does not exist, which
+  /// happens iff det = 0.
+  /// Assert-fails if the matrix is not square.
+  Fraction determinant(FracMatrix *inverse = nullptr) const;
 };
 
 } // namespace presburger

diff  --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp
index ce6253e0bda93a2..ae97e456d9820cf 100644
--- a/mlir/lib/Analysis/Presburger/Matrix.cpp
+++ b/mlir/lib/Analysis/Presburger/Matrix.cpp
@@ -432,4 +432,120 @@ MPInt IntMatrix::normalizeRow(unsigned row, unsigned cols) {
 
 MPInt IntMatrix::normalizeRow(unsigned row) {
+  // Normalize over the whole row by forwarding the matrix's full column
+  // count to the two-argument overload.
   return normalizeRow(row, getNumColumns());
+}
+
+/// Compute det(M) by rational Gaussian elimination. If `inverse` is non-null
+/// and det(M) != 0, store det(M) * M^-1 (an integer matrix) in `*inverse`.
+/// Passing nullptr (the default) computes only the determinant.
+MPInt IntMatrix::determinant(IntMatrix *inverse) const {
+  assert(nRows == nColumns &&
+         "determinant can only be calculated for square matrices!");
+
+  FracMatrix m(*this);
+
+  // Only ask for the rational inverse when the caller wants the integer one;
+  // without it, elimination to upper-triangular form is enough.
+  FracMatrix fracInverse(nRows, nColumns);
+  MPInt detM = m.determinant(inverse ? &fracInverse : nullptr).getAsInteger();
+
+  // A singular matrix has no inverse; leave *inverse untouched.
+  if (detM == 0)
+    return MPInt(0);
+
+  // Previously this dereferenced `inverse` unconditionally, crashing when
+  // called with the default nullptr.
+  if (!inverse)
+    return detM;
+
+  // Scaling the rational inverse by det(M) yields the integer inverse.
+  *inverse = IntMatrix(nRows, nColumns);
+  for (unsigned i = 0; i < nRows; i++)
+    for (unsigned j = 0; j < nColumns; j++)
+      inverse->at(i, j) = (fracInverse.at(i, j) * detM).getAsInteger();
+
+  return detM;
+}
+
+/// Build the generic rational identity of size `dimension` and wrap it as a
+/// FracMatrix.
+FracMatrix FracMatrix::identity(unsigned dimension) {
+  return FracMatrix(Matrix<Fraction>::identity(dimension));
+}
+
+/// Allocate a rational matrix with the same shape as `m`, then copy every
+/// entry across, converting each MPInt entry to a Fraction.
+FracMatrix::FracMatrix(IntMatrix m)
+    : FracMatrix(m.getNumRows(), m.getNumColumns()) {
+  const unsigned numRows = m.getNumRows();
+  const unsigned numCols = m.getNumColumns();
+  for (unsigned row = 0; row < numRows; ++row)
+    for (unsigned col = 0; col < numCols; ++col)
+      at(row, col) = m.at(row, col);
+}
+
+/// Compute the determinant by Gaussian elimination on a copy of the matrix.
+/// If `inverse` is non-null and the matrix is nonsingular, the inverse is
+/// stored in `*inverse`; on a singular matrix 0 is returned and `*inverse`
+/// is left untouched.
+Fraction FracMatrix::determinant(FracMatrix *inverse) const {
+  assert(nRows == nColumns &&
+         "determinant can only be calculated for square matrices!");
+
+  FracMatrix m(*this);
+  FracMatrix tempInv(nRows, nColumns);
+  if (inverse)
+    tempInv = FracMatrix::identity(nRows);
+
+  // Every row swap performed while pivoting negates the determinant of `m`
+  // relative to the original matrix. Track the parity so that the product
+  // of the diagonal can be corrected at the end.
+  bool negateResult = false;
+
+  Fraction a, b;
+  // Make the matrix into upper triangular form using
+  // gaussian elimination with row operations.
+  // If inverse is required, we apply more operations
+  // to turn the matrix into diagonal form. We apply
+  // the same operations to the inverse matrix,
+  // which is initially identity.
+  // Either way, the product of the diagonal elements,
+  // with its sign flipped once per row swap, is then
+  // the determinant.
+  for (unsigned i = 0; i < nRows; i++) {
+    if (m(i, i) == 0) {
+      // First ensure that the diagonal element is nonzero, by swapping
+      // row i with a later row that is nonzero in column i.
+      for (unsigned j = i + 1; j < nRows; j++) {
+        if (m(j, i) != 0) {
+          m.swapRows(j, i);
+          if (inverse)
+            tempInv.swapRows(j, i);
+          negateResult = !negateResult;
+          break;
+        }
+      }
+    }
+
+    // No nonzero pivot in column i: the matrix is singular. This returns
+    // before *inverse is written.
+    b = m.at(i, i);
+    if (b == 0)
+      return 0;
+
+    // Set all elements above the
+    // diagonal to zero (only needed for the inverse).
+    if (inverse) {
+      for (unsigned j = 0; j < i; j++) {
+        if (m.at(j, i) == 0)
+          continue;
+        a = m.at(j, i);
+        // Set element (j, i) to zero
+        // by subtracting the ith row,
+        // appropriately scaled.
+        m.addToRow(i, j, -a / b);
+        tempInv.addToRow(i, j, -a / b);
+      }
+    }
+
+    // Set all elements below the
+    // diagonal to zero.
+    for (unsigned j = i + 1; j < nRows; j++) {
+      if (m.at(j, i) == 0)
+        continue;
+      a = m.at(j, i);
+      // Set element (j, i) to zero
+      // by subtracting the ith row,
+      // appropriately scaled.
+      m.addToRow(i, j, -a / b);
+      if (inverse)
+        tempInv.addToRow(i, j, -a / b);
+    }
+  }
+
+  // Now only diagonal elements of m are nonzero, but they are
+  // not necessarily 1. To get the true inverse, we should
+  // normalize them and apply the same scale to the inverse matrix.
+  // For efficiency we skip scaling m and just scale tempInv appropriately.
+  if (inverse) {
+    for (unsigned i = 0; i < nRows; i++)
+      for (unsigned j = 0; j < nRows; j++)
+        tempInv.at(i, j) = tempInv.at(i, j) / m(i, i);
+
+    *inverse = std::move(tempInv);
+  }
+
+  Fraction determinant = 1;
+  for (unsigned i = 0; i < nRows; i++)
+    determinant *= m.at(i, i);
+
+  // An odd number of row swaps flips the sign.
+  return negateResult ? -determinant : determinant;
 }
\ No newline at end of file

diff  --git a/mlir/unittests/Analysis/Presburger/MatrixTest.cpp b/mlir/unittests/Analysis/Presburger/MatrixTest.cpp
index 6b23cedabf624ec..d05b05e004c5c5f 100644
--- a/mlir/unittests/Analysis/Presburger/MatrixTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/MatrixTest.cpp
@@ -7,8 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Analysis/Presburger/Matrix.h"
-#include "mlir/Analysis/Presburger/Fraction.h"
 #include "./Utils.h"
+#include "mlir/Analysis/Presburger/Fraction.h"
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
 
@@ -210,7 +210,8 @@ TEST(MatrixTest, computeHermiteNormalForm) {
   {
     // Hermite form of a unimodular matrix is the identity matrix.
     IntMatrix mat = makeIntMatrix(3, 3, {{2, 3, 6}, {3, 2, 3}, {17, 11, 16}});
-    IntMatrix hermiteForm = makeIntMatrix(3, 3, {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}});
+    IntMatrix hermiteForm =
+        makeIntMatrix(3, 3, {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}});
     checkHermiteNormalForm(mat, hermiteForm);
   }
 
@@ -241,10 +242,71 @@ TEST(MatrixTest, computeHermiteNormalForm) {
   }
 
   {
-    IntMatrix mat =
-        makeIntMatrix(3, 5, {{0, 2, 0, 7, 1}, {-1, 0, 0, -3, 0}, {0, 4, 1, 0, 8}});
-    IntMatrix hermiteForm =
-        makeIntMatrix(3, 5, {{1, 0, 0, 0, 0}, {0, 1, 0, 0, 0}, {0, 0, 1, 0, 0}});
+    IntMatrix mat = makeIntMatrix(
+        3, 5, {{0, 2, 0, 7, 1}, {-1, 0, 0, -3, 0}, {0, 4, 1, 0, 8}});
+    IntMatrix hermiteForm = makeIntMatrix(
+        3, 5, {{1, 0, 0, 0, 0}, {0, 1, 0, 0, 0}, {0, 0, 1, 0, 0}});
     checkHermiteNormalForm(mat, hermiteForm);
   }
 }
+
+TEST(MatrixTest, inverse) {
+  FracMatrix mat = makeFracMatrix(
+      2, 2, {{Fraction(2), Fraction(1)}, {Fraction(7), Fraction(0)}});
+  FracMatrix inverse = makeFracMatrix(
+      2, 2, {{Fraction(0), Fraction(1, 7)}, {Fraction(1), Fraction(-2, 7)}});
+
+  FracMatrix inv(2, 2);
+  // det([[2, 1], [7, 0]]) = 2 * 0 - 1 * 7 = -7.
+  EXPECT_EQ(mat.determinant(&inv), Fraction(-7));
+
+  EXPECT_EQ_FRAC_MATRIX(inv, inverse);
+
+  // A singular matrix has determinant zero.
+  mat = makeFracMatrix(
+      2, 2, {{Fraction(0), Fraction(1)}, {Fraction(0), Fraction(2)}});
+  Fraction det = mat.determinant(nullptr);
+
+  EXPECT_EQ(det, Fraction(0));
+
+  mat = makeFracMatrix(3, 3,
+                       {{Fraction(1), Fraction(2), Fraction(3)},
+                        {Fraction(4), Fraction(8), Fraction(6)},
+                        {Fraction(7), Fraction(8), Fraction(6)}});
+  inverse = makeFracMatrix(3, 3,
+                           {{Fraction(0), Fraction(-1, 3), Fraction(1, 3)},
+                            {Fraction(-1, 2), Fraction(5, 12), Fraction(-1, 6)},
+                            {Fraction(2, 3), Fraction(-1, 6), Fraction(0)}});
+
+  mat.determinant(&inv);
+  EXPECT_EQ_FRAC_MATRIX(inv, inverse);
+
+  // The empty (0 x 0) matrix has determinant 1 and an empty inverse.
+  mat = makeFracMatrix(0, 0, {});
+  EXPECT_EQ(mat.determinant(&inv), Fraction(1));
+  EXPECT_EQ(inv.getNumRows(), 0u);
+  EXPECT_EQ(inv.getNumColumns(), 0u);
+}
+
+TEST(MatrixTest, intInverse) {
+  IntMatrix mat = makeIntMatrix(2, 2, {{2, 1}, {7, 0}});
+  IntMatrix inverse = makeIntMatrix(2, 2, {{0, -1}, {-7, 2}});
+
+  IntMatrix inv(2, 2);
+  // det([[2, 1], [7, 0]]) = -7; the integer inverse is det * M^-1.
+  EXPECT_EQ(mat.determinant(&inv), -7);
+
+  EXPECT_EQ_INT_MATRIX(inv, inverse);
+
+  mat = makeIntMatrix(
+      4, 4, {{4, 14, 11, 3}, {13, 5, 14, 12}, {13, 9, 7, 14}, {2, 3, 12, 7}});
+  inverse = makeIntMatrix(4, 4,
+                          {{155, 1636, -579, -1713},
+                           {725, -743, 537, -111},
+                           {210, 735, -855, 360},
+                           {-715, -1409, 1401, 1482}});
+
+  mat.determinant(&inv);
+
+  EXPECT_EQ_INT_MATRIX(inv, inverse);
+
+  // A singular matrix has determinant 0 and must leave `inv` untouched.
+  mat = makeIntMatrix(2, 2, {{0, 0}, {1, 2}});
+
+  MPInt det = mat.determinant(&inv);
+
+  EXPECT_EQ(det, 0);
+  EXPECT_EQ_INT_MATRIX(inv, inverse);
+}

diff  --git a/mlir/unittests/Analysis/Presburger/Utils.h b/mlir/unittests/Analysis/Presburger/Utils.h
index ef4a67d0b8c004f..544577375dd1d1c 100644
--- a/mlir/unittests/Analysis/Presburger/Utils.h
+++ b/mlir/unittests/Analysis/Presburger/Utils.h
@@ -14,10 +14,10 @@
 #define MLIR_UNITTESTS_ANALYSIS_PRESBURGER_UTILS_H
 
 #include "mlir/Analysis/Presburger/IntegerRelation.h"
+#include "mlir/Analysis/Presburger/Matrix.h"
 #include "mlir/Analysis/Presburger/PWMAFunction.h"
 #include "mlir/Analysis/Presburger/PresburgerRelation.h"
 #include "mlir/Analysis/Presburger/Simplex.h"
-#include "mlir/Analysis/Presburger/Matrix.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/Support/LLVM.h"
 
@@ -28,7 +28,7 @@ namespace mlir {
 namespace presburger {
 
 inline IntMatrix makeIntMatrix(unsigned numRow, unsigned numColumns,
-                         ArrayRef<SmallVector<int, 8>> matrix) {
+                               ArrayRef<SmallVector<int, 8>> matrix) {
   IntMatrix results(numRow, numColumns);
   assert(matrix.size() == numRow);
   for (unsigned i = 0; i < numRow; ++i) {
@@ -40,9 +40,9 @@ inline IntMatrix makeIntMatrix(unsigned numRow, unsigned numColumns,
   return results;
 }
 
-inline Matrix<Fraction> makeFracMatrix(unsigned numRow, unsigned numColumns,
-                         ArrayRef<SmallVector<Fraction, 8>> matrix) {
-  Matrix<Fraction> results(numRow, numColumns);
+inline FracMatrix makeFracMatrix(unsigned numRow, unsigned numColumns,
+                                 ArrayRef<SmallVector<Fraction, 8>> matrix) {
+  FracMatrix results(numRow, numColumns);
   assert(matrix.size() == numRow);
   for (unsigned i = 0; i < numRow; ++i) {
     assert(matrix[i].size() == numColumns &&
@@ -53,6 +53,24 @@ inline Matrix<Fraction> makeFracMatrix(unsigned numRow, unsigned numColumns,
   return results;
 }
 
+/// Expect that two integer matrices have the same shape and entries.
+inline void EXPECT_EQ_INT_MATRIX(const IntMatrix &a, const IntMatrix &b) {
+  EXPECT_EQ(a.getNumRows(), b.getNumRows());
+  EXPECT_EQ(a.getNumColumns(), b.getNumColumns());
+  // EXPECT_EQ does not abort on failure, so stop on a shape mismatch to keep
+  // the element loop below from indexing `b` out of range.
+  if (a.getNumRows() != b.getNumRows() ||
+      a.getNumColumns() != b.getNumColumns())
+    return;
+
+  for (unsigned row = 0; row < a.getNumRows(); row++)
+    for (unsigned col = 0; col < a.getNumColumns(); col++)
+      EXPECT_EQ(a(row, col), b(row, col));
+}
+
+/// Expect that two rational matrices have the same shape and entries.
+inline void EXPECT_EQ_FRAC_MATRIX(const FracMatrix &a, const FracMatrix &b) {
+  EXPECT_EQ(a.getNumRows(), b.getNumRows());
+  EXPECT_EQ(a.getNumColumns(), b.getNumColumns());
+  // EXPECT_EQ does not abort on failure, so stop on a shape mismatch to keep
+  // the element loop below from indexing `b` out of range.
+  if (a.getNumRows() != b.getNumRows() ||
+      a.getNumColumns() != b.getNumColumns())
+    return;
+
+  for (unsigned row = 0; row < a.getNumRows(); row++)
+    for (unsigned col = 0; col < a.getNumColumns(); col++)
+      EXPECT_EQ(a(row, col), b(row, col));
+}
+
 /// lhs and rhs represent non-negative integers or positive infinity. The
 /// infinity case corresponds to when the Optional is empty.
 inline bool infinityOrUInt64LE(std::optional<MPInt> lhs,


        


More information about the Mlir-commits mailing list