[Mlir-commits] [mlir] [MLIR][Presburger] Template Matrix to allow MPInt and Fraction (PR #65272)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 4 08:44:46 PDT 2023


https://github.com/Abhinav271828 created https://github.com/llvm/llvm-project/pull/65272:

Matrix has been templated as Matrix<T>, with explicit instantiations for both MPInt and Fraction.
makeMatrix has been duplicated into makeIntMatrix and makeFracMatrix.
Fraction now implements basic arithmetic operations (+, -, /) and stream printing so that it can be used as a Matrix element type.

>From afca1cf7cc38aa3a6e9dff95c9cd0aa634e21c9e Mon Sep 17 00:00:00 2001
From: Abhinav271828 <abhinav.m at research.iiit.ac.in>
Date: Thu, 31 Aug 2023 17:30:17 +0100
Subject: [PATCH 1/4] Template Matrix to Matrix<T> (for MPInt and Fraction)
 with explicit instantiation; duplicate makeMatrix to makeIntMatrix and
 makeFracMatrix

Implement arithmetic operations for Fraction for compatibility
---
 .../mlir/Analysis/Presburger/Fraction.h       |  28 ++++-
 .../Analysis/Presburger/IntegerRelation.h     |   6 +-
 .../Analysis/Presburger/LinearTransform.h     |  12 +-
 .../include/mlir/Analysis/Presburger/Matrix.h |  42 +++----
 .../mlir/Analysis/Presburger/PWMAFunction.h   |   8 +-
 .../mlir/Analysis/Presburger/Simplex.h        |   4 +-
 mlir/include/mlir/Analysis/Presburger/Utils.h |   2 +-
 .../Analysis/FlatLinearValueConstraints.cpp   |   2 +-
 .../Analysis/Presburger/IntegerRelation.cpp   |   6 +-
 .../Analysis/Presburger/LinearTransform.cpp   |   6 +-
 mlir/lib/Analysis/Presburger/Matrix.cpp       | 114 ++++++++++--------
 mlir/lib/Analysis/Presburger/Simplex.cpp      |   4 +-
 .../Presburger/LinearTransformTest.cpp        |  14 +--
 .../Analysis/Presburger/MatrixTest.cpp        |  43 +++----
 mlir/unittests/Analysis/Presburger/Parser.h   |   2 +-
 mlir/unittests/Analysis/Presburger/Utils.h    |  20 ++-
 16 files changed, 182 insertions(+), 131 deletions(-)

diff --git a/mlir/include/mlir/Analysis/Presburger/Fraction.h b/mlir/include/mlir/Analysis/Presburger/Fraction.h
index c51b6c972bf8851..2cb90b708435353 100644
--- a/mlir/include/mlir/Analysis/Presburger/Fraction.h
+++ b/mlir/include/mlir/Analysis/Presburger/Fraction.h
@@ -15,6 +15,7 @@
 #define MLIR_ANALYSIS_PRESBURGER_FRACTION_H
 
 #include "mlir/Analysis/Presburger/MPInt.h"
+#include "mlir/Analysis/Presburger/Utils.h"
 #include "mlir/Support/MathExtras.h"
 
 namespace mlir {
@@ -30,15 +31,15 @@ struct Fraction {
   Fraction() = default;
 
   /// Construct a Fraction from a numerator and denominator.
-  Fraction(const MPInt &oNum, const MPInt &oDen) : num(oNum), den(oDen) {
+  Fraction(const MPInt &oNum, const MPInt &oDen = MPInt(1)) : num(oNum), den(oDen) {
     if (den < 0) {
       num = -num;
       den = -den;
     }
   }
   /// Overloads for passing literals.
-  Fraction(const MPInt &num, int64_t den) : Fraction(num, MPInt(den)) {}
-  Fraction(int64_t num, const MPInt &den) : Fraction(MPInt(num), den) {}
+  Fraction(const MPInt &num, int64_t den = 1) : Fraction(num, MPInt(den)) {}
+  Fraction(int64_t num, const MPInt &den = MPInt(1)) : Fraction(MPInt(num), den) {}
   Fraction(int64_t num, int64_t den) : Fraction(MPInt(num), MPInt(den)) {}
 
   // Return the value of the fraction as an integer. This should only be called
@@ -48,6 +49,10 @@ struct Fraction {
     return num / den;
   }
 
+  llvm::raw_ostream &print(llvm::raw_ostream &os) const {
+    return os << "(" << num << "/" << den << ")";
+  }
+
   /// The numerator and denominator, respectively. The denominator is always
   /// positive.
   MPInt num{0}, den{1};
@@ -99,6 +104,23 @@ inline Fraction operator*(const Fraction &x, const Fraction &y) {
   return Fraction(x.num * y.num, x.den * y.den);
 }
 
+inline Fraction operator/(const Fraction &x, const Fraction &y) {
+  return Fraction(x.num * y.den, x.den * y.num);
+}
+
+inline Fraction operator+(const Fraction &x, const Fraction &y) {
+  return Fraction(x.num * y.den + x.den * y.num, x.den * y.den);
+}
+
+inline Fraction operator-(const Fraction &x, const Fraction &y) {
+  return Fraction(x.num * y.den - x.den * y.num, x.den * y.den);
+}
+
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Fraction &x) {
+  x.print(os);
+  return os;
+}
+
 } // namespace presburger
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index 9646894736de069..eeffe58b4a63547 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -359,7 +359,7 @@ class IntegerRelation {
   /// bounded. The span of the returned vectors is guaranteed to contain all
   /// such vectors. The returned vectors are NOT guaranteed to be linearly
   /// independent. This function should not be called on empty sets.
-  Matrix getBoundedDirections() const;
+  Matrix<MPInt> getBoundedDirections() const;
 
   /// Find an integer sample point satisfying the constraints using a
   /// branch and bound algorithm with generalized basis reduction, with some
@@ -782,10 +782,10 @@ class IntegerRelation {
   PresburgerSpace space;
 
   /// Coefficients of affine equalities (in == 0 form).
-  Matrix equalities;
+  Matrix<MPInt> equalities;
 
   /// Coefficients of affine inequalities (in >= 0 form).
-  Matrix inequalities;
+  Matrix<MPInt> inequalities;
 };
 
 /// An IntegerPolyhedron represents the set of points from a PresburgerSpace
diff --git a/mlir/include/mlir/Analysis/Presburger/LinearTransform.h b/mlir/include/mlir/Analysis/Presburger/LinearTransform.h
index cd56951fe773f8b..686e846a16c78a7 100644
--- a/mlir/include/mlir/Analysis/Presburger/LinearTransform.h
+++ b/mlir/include/mlir/Analysis/Presburger/LinearTransform.h
@@ -22,8 +22,8 @@ namespace presburger {
 
 class LinearTransform {
 public:
-  explicit LinearTransform(Matrix &&oMatrix);
-  explicit LinearTransform(const Matrix &oMatrix);
+  explicit LinearTransform(Matrix<MPInt> &&oMatrix);
+  explicit LinearTransform(const Matrix<MPInt> &oMatrix);
 
   // Returns a linear transform T such that MT is M in column echelon form.
   // Also returns the number of non-zero columns in MT.
@@ -32,7 +32,7 @@ class LinearTransform {
   // strictly below that of the previous column, and all columns which have only
   // zeros are at the end.
   static std::pair<unsigned, LinearTransform>
-  makeTransformToColumnEchelon(const Matrix &m);
+  makeTransformToColumnEchelon(const Matrix<MPInt> &m);
 
   // Returns an IntegerRelation having a constraint vector vT for every
   // constraint vector v in rel, where T is this transform.
@@ -50,8 +50,12 @@ class LinearTransform {
     return matrix.postMultiplyWithColumn(colVec);
   }
 
+  // Compute the determinant of the transform by converting it to row echelon
+  // form and then taking the product of the diagonal.
+  MPInt determinant();
+
 private:
-  Matrix matrix;
+  Matrix<MPInt> matrix;
 };
 
 } // namespace presburger
diff --git a/mlir/include/mlir/Analysis/Presburger/Matrix.h b/mlir/include/mlir/Analysis/Presburger/Matrix.h
index bae1661d9ce6c60..b03737ab2f70a4a 100644
--- a/mlir/include/mlir/Analysis/Presburger/Matrix.h
+++ b/mlir/include/mlir/Analysis/Presburger/Matrix.h
@@ -14,7 +14,6 @@
 #ifndef MLIR_ANALYSIS_PRESBURGER_MATRIX_H
 #define MLIR_ANALYSIS_PRESBURGER_MATRIX_H
 
-#include "mlir/Analysis/Presburger/MPInt.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/raw_ostream.h"
@@ -32,6 +31,7 @@ namespace presburger {
 /// (i, j) is stored at data[i*nReservedColumns + j]. The reserved but unused
 /// columns always have all zero values. The reserved rows are just reserved
 /// space in the underlying SmallVector's capacity.
+template<typename T>
 class Matrix {
 public:
   Matrix() = delete;
@@ -49,21 +49,21 @@ class Matrix {
   static Matrix identity(unsigned dimension);
 
   /// Access the element at the specified row and column.
-  MPInt &at(unsigned row, unsigned column) {
+  T &at(unsigned row, unsigned column) {
     assert(row < nRows && "Row outside of range");
     assert(column < nColumns && "Column outside of range");
     return data[row * nReservedColumns + column];
   }
 
-  MPInt at(unsigned row, unsigned column) const {
+  T at(unsigned row, unsigned column) const {
     assert(row < nRows && "Row outside of range");
     assert(column < nColumns && "Column outside of range");
     return data[row * nReservedColumns + column];
   }
 
-  MPInt &operator()(unsigned row, unsigned column) { return at(row, column); }
+  T &operator()(unsigned row, unsigned column) { return at(row, column); }
 
-  MPInt operator()(unsigned row, unsigned column) const {
+  T operator()(unsigned row, unsigned column) const {
     return at(row, column);
   }
 
@@ -87,11 +87,11 @@ class Matrix {
   void reserveRows(unsigned rows);
 
   /// Get a [Mutable]ArrayRef corresponding to the specified row.
-  MutableArrayRef<MPInt> getRow(unsigned row);
-  ArrayRef<MPInt> getRow(unsigned row) const;
+  MutableArrayRef<T> getRow(unsigned row);
+  ArrayRef<T> getRow(unsigned row) const;
 
   /// Set the specified row to `elems`.
-  void setRow(unsigned row, ArrayRef<MPInt> elems);
+  void setRow(unsigned row, ArrayRef<T> elems);
 
   /// Insert columns having positions pos, pos + 1, ... pos + count - 1.
   /// Columns that were at positions 0 to pos - 1 will stay where they are;
@@ -125,23 +125,23 @@ class Matrix {
 
   void copyRow(unsigned sourceRow, unsigned targetRow);
 
-  void fillRow(unsigned row, const MPInt &value);
-  void fillRow(unsigned row, int64_t value) { fillRow(row, MPInt(value)); }
+  void fillRow(unsigned row, const T &value);
+  void fillRow(unsigned row, int64_t value) { fillRow(row, T(value)); }
 
   /// Add `scale` multiples of the source row to the target row.
-  void addToRow(unsigned sourceRow, unsigned targetRow, const MPInt &scale);
+  void addToRow(unsigned sourceRow, unsigned targetRow, const T &scale);
   void addToRow(unsigned sourceRow, unsigned targetRow, int64_t scale) {
-    addToRow(sourceRow, targetRow, MPInt(scale));
+    addToRow(sourceRow, targetRow, T(scale));
   }
   /// Add `scale` multiples of the rowVec row to the specified row.
-  void addToRow(unsigned row, ArrayRef<MPInt> rowVec, const MPInt &scale);
+  void addToRow(unsigned row, ArrayRef<T> rowVec, const T &scale);
 
   /// Add `scale` multiples of the source column to the target column.
   void addToColumn(unsigned sourceColumn, unsigned targetColumn,
-                   const MPInt &scale);
+                   const T &scale);
   void addToColumn(unsigned sourceColumn, unsigned targetColumn,
                    int64_t scale) {
-    addToColumn(sourceColumn, targetColumn, MPInt(scale));
+    addToColumn(sourceColumn, targetColumn, T(scale));
   }
 
   /// Negate the specified column.
@@ -152,18 +152,18 @@ class Matrix {
 
   /// Divide the first `nCols` of the specified row by their GCD.
   /// Returns the GCD of the first `nCols` of the specified row.
-  MPInt normalizeRow(unsigned row, unsigned nCols);
+  T normalizeRow(unsigned row, unsigned nCols);
   /// Divide the columns of the specified row by their GCD.
   /// Returns the GCD of the columns of the specified row.
-  MPInt normalizeRow(unsigned row);
+  T normalizeRow(unsigned row);
 
   /// The given vector is interpreted as a row vector v. Post-multiply v with
   /// this matrix, say M, and return vM.
-  SmallVector<MPInt, 8> preMultiplyWithRow(ArrayRef<MPInt> rowVec) const;
+  SmallVector<T, 8> preMultiplyWithRow(ArrayRef<T> rowVec) const;
 
   /// The given vector is interpreted as a column vector v. Pre-multiply v with
   /// this matrix, say M, and return Mv.
-  SmallVector<MPInt, 8> postMultiplyWithColumn(ArrayRef<MPInt> colVec) const;
+  SmallVector<T, 8> postMultiplyWithColumn(ArrayRef<T> colVec) const;
 
   /// Given the current matrix M, returns the matrices H, U such that H is the
   /// column hermite normal form of M, i.e. H = M * U, where U is unimodular and
@@ -192,7 +192,7 @@ class Matrix {
   unsigned appendExtraRow();
   /// Same as above, but copy the given elements into the row. The length of
   /// `elems` must be equal to the number of columns.
-  unsigned appendExtraRow(ArrayRef<MPInt> elems);
+  unsigned appendExtraRow(ArrayRef<T> elems);
 
   /// Print the matrix.
   void print(raw_ostream &os) const;
@@ -211,7 +211,7 @@ class Matrix {
 
   /// Stores the data. data.size() is equal to nRows * nReservedColumns.
   /// data.capacity() / nReservedColumns is the number of reserved rows.
-  SmallVector<MPInt, 16> data;
+  SmallVector<T, 16> data;
 };
 
 } // namespace presburger
diff --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
index ea3456624e72d4e..0b3804fc08a60e0 100644
--- a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
+++ b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
@@ -40,13 +40,13 @@ enum class OrderingKind { EQ, NE, LT, LE, GT, GE };
 /// value of the function at a specified point.
 class MultiAffineFunction {
 public:
-  MultiAffineFunction(const PresburgerSpace &space, const Matrix &output)
+  MultiAffineFunction(const PresburgerSpace &space, const Matrix<MPInt> &output)
       : space(space), output(output),
         divs(space.getNumVars() - space.getNumRangeVars()) {
     assertIsConsistent();
   }
 
-  MultiAffineFunction(const PresburgerSpace &space, const Matrix &output,
+  MultiAffineFunction(const PresburgerSpace &space, const Matrix<MPInt> &output,
                       const DivisionRepr &divs)
       : space(space), output(output), divs(divs) {
     assertIsConsistent();
@@ -65,7 +65,7 @@ class MultiAffineFunction {
   PresburgerSpace getOutputSpace() const { return space.getRangeSpace(); }
 
   /// Get a matrix with each row representing row^th output expression.
-  const Matrix &getOutputMatrix() const { return output; }
+  const Matrix<MPInt> &getOutputMatrix() const { return output; }
   /// Get the `i^th` output expression.
   ArrayRef<MPInt> getOutputExpr(unsigned i) const { return output.getRow(i); }
 
@@ -124,7 +124,7 @@ class MultiAffineFunction {
   /// The function's output is a tuple of integers, with the ith element of the
   /// tuple defined by the affine expression given by the ith row of this output
   /// matrix.
-  Matrix output;
+  Matrix<MPInt> output;
 
   /// Storage for division representation for each local variable in space.
   DivisionRepr divs;
diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h
index 470d483cbb56481..6a7f05999df2ceb 100644
--- a/mlir/include/mlir/Analysis/Presburger/Simplex.h
+++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h
@@ -338,7 +338,7 @@ class SimplexBase {
   unsigned nSymbol;
 
   /// The matrix representing the tableau.
-  Matrix tableau;
+  Matrix<MPInt> tableau;
 
   /// This is true if the tableau has been detected to be empty, false
   /// otherwise.
@@ -861,7 +861,7 @@ class Simplex : public SimplexBase {
 
   /// Reduce the given basis, starting at the specified level, using general
   /// basis reduction.
-  void reduceBasis(Matrix &basis, unsigned level);
+  void reduceBasis(Matrix<MPInt> &basis, unsigned level);
 };
 
 /// Takes a snapshot of the simplex state on construction and rolls back to the
diff --git a/mlir/include/mlir/Analysis/Presburger/Utils.h b/mlir/include/mlir/Analysis/Presburger/Utils.h
index a3000a26c3f3d76..d3822ed572f8ee8 100644
--- a/mlir/include/mlir/Analysis/Presburger/Utils.h
+++ b/mlir/include/mlir/Analysis/Presburger/Utils.h
@@ -182,7 +182,7 @@ class DivisionRepr {
   /// Each row of the Matrix represents a single division dividend. The
   /// `i^th` row represents the dividend of the variable at `divOffset + i`
   /// in the constraint system (and the `i^th` local variable).
-  Matrix dividends;
+  Matrix<MPInt> dividends;
 
   /// Denominators of each division. If a denominator of a division is `0`, the
   /// division variable is considered to not have a division representation.
diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
index 3f000182250069d..31aff1a216bacc3 100644
--- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
+++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
@@ -1292,7 +1292,7 @@ mlir::getMultiAffineFunctionFromMap(AffineMap map,
          "AffineMap cannot produce divs without local representation");
 
   // TODO: We shouldn't have to do this conversion.
-  Matrix mat(map.getNumResults(), map.getNumInputs() + divs.getNumDivs() + 1);
+  Matrix<MPInt> mat(map.getNumResults(), map.getNumInputs() + divs.getNumDivs() + 1);
   for (unsigned i = 0, e = flattenedExprs.size(); i < e; ++i)
     for (unsigned j = 0, f = flattenedExprs[i].size(); j < f; ++j)
       mat(i, j) = flattenedExprs[i][j];
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index 75c6adbf6bbc2b8..ba4ce32c355b4d3 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -676,7 +676,7 @@ bool IntegerRelation::isEmptyByGCDTest() const {
 //
 // It is sufficient to check the perpendiculars of the constraints, as the set
 // of perpendiculars which are bounded must span all bounded directions.
-Matrix IntegerRelation::getBoundedDirections() const {
+Matrix<MPInt> IntegerRelation::getBoundedDirections() const {
   // Note that it is necessary to add the equalities too (which the constructor
   // does) even though we don't need to check if they are bounded; whether an
   // inequality is bounded or not depends on what other constraints, including
@@ -697,7 +697,7 @@ Matrix IntegerRelation::getBoundedDirections() const {
   // The direction vector is given by the coefficients and does not include the
   // constant term, so the matrix has one fewer column.
   unsigned dirsNumCols = getNumCols() - 1;
-  Matrix dirs(boundedIneqs.size() + getNumEqualities(), dirsNumCols);
+  Matrix<MPInt> dirs(boundedIneqs.size() + getNumEqualities(), dirsNumCols);
 
   // Copy the bounded inequalities.
   unsigned row = 0;
@@ -783,7 +783,7 @@ IntegerRelation::findIntegerSample() const {
   // m is a matrix containing, in each row, a vector in which S is
   // bounded, such that the linear span of all these dimensions contains all
   // bounded dimensions in S.
-  Matrix m = getBoundedDirections();
+  Matrix<MPInt> m = getBoundedDirections();
   // In column echelon form, each row of m occupies only the first rank(m)
   // columns and has zeros on the other columns. The transform T that brings S
   // to column echelon form is unimodular as well, so this is a suitable
diff --git a/mlir/lib/Analysis/Presburger/LinearTransform.cpp b/mlir/lib/Analysis/Presburger/LinearTransform.cpp
index e7ad3ecf4306d38..d25e76d9229f605 100644
--- a/mlir/lib/Analysis/Presburger/LinearTransform.cpp
+++ b/mlir/lib/Analysis/Presburger/LinearTransform.cpp
@@ -12,11 +12,11 @@
 using namespace mlir;
 using namespace presburger;
 
-LinearTransform::LinearTransform(Matrix &&oMatrix) : matrix(oMatrix) {}
-LinearTransform::LinearTransform(const Matrix &oMatrix) : matrix(oMatrix) {}
+LinearTransform::LinearTransform(Matrix<MPInt> &&oMatrix) : matrix(oMatrix) {}
+LinearTransform::LinearTransform(const Matrix<MPInt> &oMatrix) : matrix(oMatrix) {}
 
 std::pair<unsigned, LinearTransform>
-LinearTransform::makeTransformToColumnEchelon(const Matrix &m) {
+LinearTransform::makeTransformToColumnEchelon(const Matrix<MPInt> &m) {
   // Compute the hermite normal form of m. This, is by definition, is in column
   // echelon form.
   auto [h, u] = m.computeHermiteNormalForm();
diff --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp
index 4ee81c61a53a3b5..c19e5d8d49fec37 100644
--- a/mlir/lib/Analysis/Presburger/Matrix.cpp
+++ b/mlir/lib/Analysis/Presburger/Matrix.cpp
@@ -7,13 +7,14 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Analysis/Presburger/Matrix.h"
+#include "mlir/Analysis/Presburger/Fraction.h"
 #include "mlir/Analysis/Presburger/Utils.h"
 #include "llvm/Support/MathExtras.h"
 
 using namespace mlir;
 using namespace presburger;
 
-Matrix::Matrix(unsigned rows, unsigned columns, unsigned reservedRows,
+template <typename T> Matrix<T>::Matrix(unsigned rows, unsigned columns, unsigned reservedRows,
                unsigned reservedColumns)
     : nRows(rows), nColumns(columns),
       nReservedColumns(std::max(nColumns, reservedColumns)),
@@ -21,27 +22,27 @@ Matrix::Matrix(unsigned rows, unsigned columns, unsigned reservedRows,
   data.reserve(std::max(nRows, reservedRows) * nReservedColumns);
 }
 
-Matrix Matrix::identity(unsigned dimension) {
+template <typename T> Matrix<T> Matrix<T>::identity(unsigned dimension) {
   Matrix matrix(dimension, dimension);
   for (unsigned i = 0; i < dimension; ++i)
     matrix(i, i) = 1;
   return matrix;
 }
 
-unsigned Matrix::getNumReservedRows() const {
+template <typename T> unsigned Matrix<T>::getNumReservedRows() const {
   return data.capacity() / nReservedColumns;
 }
 
-void Matrix::reserveRows(unsigned rows) {
+template <typename T> void Matrix<T>::reserveRows(unsigned rows) {
   data.reserve(rows * nReservedColumns);
 }
 
-unsigned Matrix::appendExtraRow() {
+template <typename T> unsigned Matrix<T>::appendExtraRow() {
   resizeVertically(nRows + 1);
   return nRows - 1;
 }
 
-unsigned Matrix::appendExtraRow(ArrayRef<MPInt> elems) {
+template <typename T> unsigned Matrix<T>::appendExtraRow(ArrayRef<T> elems) {
   assert(elems.size() == nColumns && "elems must match row length!");
   unsigned row = appendExtraRow();
   for (unsigned col = 0; col < nColumns; ++col)
@@ -49,24 +50,24 @@ unsigned Matrix::appendExtraRow(ArrayRef<MPInt> elems) {
   return row;
 }
 
-void Matrix::resizeHorizontally(unsigned newNColumns) {
+template <typename T> void Matrix<T>::resizeHorizontally(unsigned newNColumns) {
   if (newNColumns < nColumns)
     removeColumns(newNColumns, nColumns - newNColumns);
   if (newNColumns > nColumns)
     insertColumns(nColumns, newNColumns - nColumns);
 }
 
-void Matrix::resize(unsigned newNRows, unsigned newNColumns) {
+template <typename T> void Matrix<T>::resize(unsigned newNRows, unsigned newNColumns) {
   resizeHorizontally(newNColumns);
   resizeVertically(newNRows);
 }
 
-void Matrix::resizeVertically(unsigned newNRows) {
+template <typename T> void Matrix<T>::resizeVertically(unsigned newNRows) {
   nRows = newNRows;
   data.resize(nRows * nReservedColumns);
 }
 
-void Matrix::swapRows(unsigned row, unsigned otherRow) {
+template <typename T> void Matrix<T>::swapRows(unsigned row, unsigned otherRow) {
   assert((row < getNumRows() && otherRow < getNumRows()) &&
          "Given row out of bounds");
   if (row == otherRow)
@@ -75,7 +76,7 @@ void Matrix::swapRows(unsigned row, unsigned otherRow) {
     std::swap(at(row, col), at(otherRow, col));
 }
 
-void Matrix::swapColumns(unsigned column, unsigned otherColumn) {
+template <typename T> void Matrix<T>::swapColumns(unsigned column, unsigned otherColumn) {
   assert((column < getNumColumns() && otherColumn < getNumColumns()) &&
          "Given column out of bounds");
   if (column == otherColumn)
@@ -84,23 +85,23 @@ void Matrix::swapColumns(unsigned column, unsigned otherColumn) {
     std::swap(at(row, column), at(row, otherColumn));
 }
 
-MutableArrayRef<MPInt> Matrix::getRow(unsigned row) {
+template <typename T> MutableArrayRef<T> Matrix<T>::getRow(unsigned row) {
   return {&data[row * nReservedColumns], nColumns};
 }
 
-ArrayRef<MPInt> Matrix::getRow(unsigned row) const {
+template <typename T> ArrayRef<T> Matrix<T>::getRow(unsigned row) const {
   return {&data[row * nReservedColumns], nColumns};
 }
 
-void Matrix::setRow(unsigned row, ArrayRef<MPInt> elems) {
+template <typename T> void Matrix<T>::setRow(unsigned row, ArrayRef<T> elems) {
   assert(elems.size() == getNumColumns() &&
          "elems size must match row length!");
   for (unsigned i = 0, e = getNumColumns(); i < e; ++i)
     at(row, i) = elems[i];
 }
 
-void Matrix::insertColumn(unsigned pos) { insertColumns(pos, 1); }
-void Matrix::insertColumns(unsigned pos, unsigned count) {
+template <typename T> void Matrix<T>::insertColumn(unsigned pos) { insertColumns(pos, 1); }
+template <typename T> void Matrix<T>::insertColumns(unsigned pos, unsigned count) {
   if (count == 0)
     return;
   assert(pos <= nColumns);
@@ -115,7 +116,7 @@ void Matrix::insertColumns(unsigned pos, unsigned count) {
     for (int ci = nReservedColumns - 1; ci >= 0; --ci) {
       unsigned r = ri;
       unsigned c = ci;
-      MPInt &dest = data[r * nReservedColumns + c];
+      T &dest = data[r * nReservedColumns + c];
       if (c >= nColumns) { // NOLINT
         // Out of bounds columns are zero-initialized. NOLINT because clang-tidy
         // complains about this branch being the same as the c >= pos one.
@@ -141,8 +142,8 @@ void Matrix::insertColumns(unsigned pos, unsigned count) {
   }
 }
 
-void Matrix::removeColumn(unsigned pos) { removeColumns(pos, 1); }
-void Matrix::removeColumns(unsigned pos, unsigned count) {
+template <typename T> void Matrix<T>::removeColumn(unsigned pos) { removeColumns(pos, 1); }
+template <typename T> void Matrix<T>::removeColumns(unsigned pos, unsigned count) {
   if (count == 0)
     return;
   assert(pos + count - 1 < nColumns);
@@ -155,8 +156,8 @@ void Matrix::removeColumns(unsigned pos, unsigned count) {
   nColumns -= count;
 }
 
-void Matrix::insertRow(unsigned pos) { insertRows(pos, 1); }
-void Matrix::insertRows(unsigned pos, unsigned count) {
+template <typename T> void Matrix<T>::insertRow(unsigned pos) { insertRows(pos, 1); }
+template <typename T> void Matrix<T>::insertRows(unsigned pos, unsigned count) {
   if (count == 0)
     return;
 
@@ -169,8 +170,8 @@ void Matrix::insertRows(unsigned pos, unsigned count) {
       at(r, c) = 0;
 }
 
-void Matrix::removeRow(unsigned pos) { removeRows(pos, 1); }
-void Matrix::removeRows(unsigned pos, unsigned count) {
+template <typename T> void Matrix<T>::removeRow(unsigned pos) { removeRows(pos, 1); }
+template <typename T> void Matrix<T>::removeRows(unsigned pos, unsigned count) {
   if (count == 0)
     return;
   assert(pos + count - 1 <= nRows);
@@ -179,76 +180,76 @@ void Matrix::removeRows(unsigned pos, unsigned count) {
   resizeVertically(nRows - count);
 }
 
-void Matrix::copyRow(unsigned sourceRow, unsigned targetRow) {
+template <typename T> void Matrix<T>::copyRow(unsigned sourceRow, unsigned targetRow) {
   if (sourceRow == targetRow)
     return;
   for (unsigned c = 0; c < nColumns; ++c)
     at(targetRow, c) = at(sourceRow, c);
 }
 
-void Matrix::fillRow(unsigned row, const MPInt &value) {
+template <typename T> void Matrix<T>::fillRow(unsigned row, const T &value) {
   for (unsigned col = 0; col < nColumns; ++col)
     at(row, col) = value;
 }
 
-void Matrix::addToRow(unsigned sourceRow, unsigned targetRow,
-                      const MPInt &scale) {
+template <typename T> void Matrix<T>::addToRow(unsigned sourceRow, unsigned targetRow,
+                      const T &scale) {
   addToRow(targetRow, getRow(sourceRow), scale);
 }
 
-void Matrix::addToRow(unsigned row, ArrayRef<MPInt> rowVec,
-                      const MPInt &scale) {
+template <typename T> void Matrix<T>::addToRow(unsigned row, ArrayRef<T> rowVec,
+                      const T &scale) {
   if (scale == 0)
     return;
   for (unsigned col = 0; col < nColumns; ++col)
-    at(row, col) += scale * rowVec[col];
+    at(row, col) = at(row, col) + scale * rowVec[col];
 }
 
-void Matrix::addToColumn(unsigned sourceColumn, unsigned targetColumn,
-                         const MPInt &scale) {
+template <typename T> void Matrix<T>::addToColumn(unsigned sourceColumn, unsigned targetColumn,
+                         const T &scale) {
   if (scale == 0)
     return;
   for (unsigned row = 0, e = getNumRows(); row < e; ++row)
-    at(row, targetColumn) += scale * at(row, sourceColumn);
+    at(row, targetColumn) = at(row, targetColumn) + scale * at(row, sourceColumn);
 }
 
-void Matrix::negateColumn(unsigned column) {
+template <typename T> void Matrix<T>::negateColumn(unsigned column) {
   for (unsigned row = 0, e = getNumRows(); row < e; ++row)
     at(row, column) = -at(row, column);
 }
 
-void Matrix::negateRow(unsigned row) {
+template <typename T> void Matrix<T>::negateRow(unsigned row) {
   for (unsigned column = 0, e = getNumColumns(); column < e; ++column)
     at(row, column) = -at(row, column);
 }
 
-MPInt Matrix::normalizeRow(unsigned row, unsigned cols) {
+template <> MPInt Matrix<MPInt>::normalizeRow(unsigned row, unsigned cols) {
   return normalizeRange(getRow(row).slice(0, cols));
 }
 
-MPInt Matrix::normalizeRow(unsigned row) {
+template <> MPInt Matrix<MPInt>::normalizeRow(unsigned row) {
   return normalizeRow(row, getNumColumns());
 }
 
-SmallVector<MPInt, 8> Matrix::preMultiplyWithRow(ArrayRef<MPInt> rowVec) const {
+template <typename T> SmallVector<T, 8> Matrix<T>::preMultiplyWithRow(ArrayRef<T> rowVec) const {
   assert(rowVec.size() == getNumRows() && "Invalid row vector dimension!");
 
-  SmallVector<MPInt, 8> result(getNumColumns(), MPInt(0));
+  SmallVector<T, 8> result(getNumColumns(), T(0));
   for (unsigned col = 0, e = getNumColumns(); col < e; ++col)
     for (unsigned i = 0, e = getNumRows(); i < e; ++i)
-      result[col] += rowVec[i] * at(i, col);
+      result[col] = result[col] + rowVec[i] * at(i, col);
   return result;
 }
 
-SmallVector<MPInt, 8>
-Matrix::postMultiplyWithColumn(ArrayRef<MPInt> colVec) const {
+template <typename T> SmallVector<T, 8>
+Matrix<T>::postMultiplyWithColumn(ArrayRef<T> colVec) const {
   assert(getNumColumns() == colVec.size() &&
          "Invalid column vector dimension!");
 
-  SmallVector<MPInt, 8> result(getNumRows(), MPInt(0));
+  SmallVector<T, 8> result(getNumRows(), T(0));
   for (unsigned row = 0, e = getNumRows(); row < e; row++)
     for (unsigned i = 0, e = getNumColumns(); i < e; i++)
-      result[row] += at(row, i) * colVec[i];
+      result[row] = result[row] + at(row, i) * colVec[i];
   return result;
 }
 
@@ -257,8 +258,8 @@ Matrix::postMultiplyWithColumn(ArrayRef<MPInt> colVec) const {
 /// sourceCol. This brings M(row, targetCol) to the range [0, M(row,
 /// sourceCol)). Apply the same column operation to otherMatrix, with the same
 /// integer multiple.
-static void modEntryColumnOperation(Matrix &m, unsigned row, unsigned sourceCol,
-                                    unsigned targetCol, Matrix &otherMatrix) {
+static void modEntryColumnOperation(Matrix<MPInt> &m, unsigned row, unsigned sourceCol,
+                                    unsigned targetCol, Matrix<MPInt> &otherMatrix) {
   assert(m(row, sourceCol) != 0 && "Cannot divide by zero!");
   assert(m(row, sourceCol) > 0 && "Source must be positive!");
   MPInt ratio = -floorDiv(m(row, targetCol), m(row, sourceCol));
@@ -266,12 +267,12 @@ static void modEntryColumnOperation(Matrix &m, unsigned row, unsigned sourceCol,
   otherMatrix.addToColumn(sourceCol, targetCol, ratio);
 }
 
-std::pair<Matrix, Matrix> Matrix::computeHermiteNormalForm() const {
+template <> std::pair<Matrix<MPInt>, Matrix<MPInt>> Matrix<MPInt>::computeHermiteNormalForm() const {
   // We start with u as an identity matrix and perform operations on h until h
   // is in hermite normal form. We apply the same sequence of operations on u to
   // obtain a transform that takes h to hermite normal form.
-  Matrix h = *this;
-  Matrix u = Matrix::identity(h.getNumColumns());
+  Matrix<MPInt> h = *this;
+  Matrix<MPInt> u = Matrix<MPInt>::identity(h.getNumColumns());
 
   unsigned echelonCol = 0;
   // Invariant: in all rows above row, all columns from echelonCol onwards
@@ -352,7 +353,7 @@ std::pair<Matrix, Matrix> Matrix::computeHermiteNormalForm() const {
   return {h, u};
 }
 
-void Matrix::print(raw_ostream &os) const {
+template <typename T> void Matrix<T>::print(raw_ostream &os) const {
   for (unsigned row = 0; row < nRows; ++row) {
     for (unsigned column = 0; column < nColumns; ++column)
       os << at(row, column) << ' ';
@@ -360,9 +361,9 @@ void Matrix::print(raw_ostream &os) const {
   }
 }
 
-void Matrix::dump() const { print(llvm::errs()); }
+template <typename T> void Matrix<T>::dump() const { print(llvm::errs()); }
 
-bool Matrix::hasConsistentState() const {
+template <typename T> bool Matrix<T>::hasConsistentState() const {
   if (data.size() != nRows * nReservedColumns)
     return false;
   if (nColumns > nReservedColumns)
@@ -375,3 +376,12 @@ bool Matrix::hasConsistentState() const {
 #endif
   return true;
 }
+
+namespace mlir
+{
+  namespace presburger
+  {
+    template class Matrix<MPInt>;
+    template class Matrix<Fraction>;
+  }
+}
\ No newline at end of file
diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp
index 61c39bd315f1878..a6d02797af1a301 100644
--- a/mlir/lib/Analysis/Presburger/Simplex.cpp
+++ b/mlir/lib/Analysis/Presburger/Simplex.cpp
@@ -1801,7 +1801,7 @@ class presburger::GBRSimplex {
 ///
 /// When incrementing i, no cached f values get invalidated. However, the cached
 /// duals do get invalidated as the duals for the higher levels are different.
-void Simplex::reduceBasis(Matrix &basis, unsigned level) {
+void Simplex::reduceBasis(Matrix<MPInt> &basis, unsigned level) {
   const Fraction epsilon(3, 4);
 
   if (level == basis.getNumRows() - 1)
@@ -1975,7 +1975,7 @@ std::optional<SmallVector<MPInt, 8>> Simplex::findIntegerSample() {
     return {};
 
   unsigned nDims = var.size();
-  Matrix basis = Matrix::identity(nDims);
+  Matrix<MPInt> basis = Matrix<MPInt>::identity(nDims);
 
   unsigned level = 0;
   // The snapshot just before constraining a direction to a value at each level.
diff --git a/mlir/unittests/Analysis/Presburger/LinearTransformTest.cpp b/mlir/unittests/Analysis/Presburger/LinearTransformTest.cpp
index 32d9e532e1f67dc..07c1f9069bca21c 100644
--- a/mlir/unittests/Analysis/Presburger/LinearTransformTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/LinearTransformTest.cpp
@@ -13,7 +13,7 @@
 using namespace mlir;
 using namespace presburger;
 
-void testColumnEchelonForm(const Matrix &m, unsigned expectedRank) {
+void testColumnEchelonForm(const Matrix<MPInt> &m, unsigned expectedRank) {
   unsigned lastAllowedNonZeroCol = 0;
   std::pair<unsigned, LinearTransform> result =
       LinearTransform::makeTransformToColumnEchelon(m);
@@ -42,21 +42,21 @@ void testColumnEchelonForm(const Matrix &m, unsigned expectedRank) {
 
 TEST(LinearTransformTest, transformToColumnEchelonTest) {
   // m1, m2, m3 are rank 1 matrices -- the first and second rows are identical.
-  Matrix m1(2, 2);
+  Matrix<MPInt> m1(2, 2);
   m1(0, 0) = 4;
   m1(0, 1) = -7;
   m1(1, 0) = 4;
   m1(1, 1) = -7;
   testColumnEchelonForm(m1, 1u);
 
-  Matrix m2(2, 2);
+  Matrix<MPInt> m2(2, 2);
   m2(0, 0) = -4;
   m2(0, 1) = 7;
   m2(1, 0) = 4;
   m2(1, 1) = -7;
   testColumnEchelonForm(m2, 1u);
 
-  Matrix m3(2, 2);
+  Matrix<MPInt> m3(2, 2);
   m3(0, 0) = -4;
   m3(0, 1) = -7;
   m3(1, 0) = -4;
@@ -64,21 +64,21 @@ TEST(LinearTransformTest, transformToColumnEchelonTest) {
   testColumnEchelonForm(m3, 1u);
 
   // m4, m5, m6 are rank 2 matrices -- the first and second rows are different.
-  Matrix m4(2, 2);
+  Matrix<MPInt> m4(2, 2);
   m4(0, 0) = 4;
   m4(0, 1) = -7;
   m4(1, 0) = -4;
   m4(1, 1) = -7;
   testColumnEchelonForm(m4, 2u);
 
-  Matrix m5(2, 2);
+  Matrix<MPInt> m5(2, 2);
   m5(0, 0) = -4;
   m5(0, 1) = 7;
   m5(1, 0) = 4;
   m5(1, 1) = 7;
   testColumnEchelonForm(m5, 2u);
 
-  Matrix m6(2, 2);
+  Matrix<MPInt> m6(2, 2);
   m6(0, 0) = -4;
   m6(0, 1) = -7;
   m6(1, 0) = 4;
diff --git a/mlir/unittests/Analysis/Presburger/MatrixTest.cpp b/mlir/unittests/Analysis/Presburger/MatrixTest.cpp
index 5a1a827e6bb9a88..7a226936c5751eb 100644
--- a/mlir/unittests/Analysis/Presburger/MatrixTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/MatrixTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Analysis/Presburger/Matrix.h"
+#include "mlir/Analysis/Presburger/Fraction.h"
 #include "./Utils.h"
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
@@ -15,7 +16,7 @@ using namespace mlir;
 using namespace presburger;
 
 TEST(MatrixTest, ReadWrite) {
-  Matrix mat(5, 5);
+  Matrix<MPInt> mat(5, 5);
   for (unsigned row = 0; row < 5; ++row)
     for (unsigned col = 0; col < 5; ++col)
       mat(row, col) = 10 * row + col;
@@ -25,7 +26,7 @@ TEST(MatrixTest, ReadWrite) {
 }
 
 TEST(MatrixTest, SwapColumns) {
-  Matrix mat(5, 5);
+  Matrix<MPInt> mat(5, 5);
   for (unsigned row = 0; row < 5; ++row)
     for (unsigned col = 0; col < 5; ++col)
       mat(row, col) = col == 3 ? 1 : 0;
@@ -47,7 +48,7 @@ TEST(MatrixTest, SwapColumns) {
 }
 
 TEST(MatrixTest, SwapRows) {
-  Matrix mat(5, 5);
+  Matrix<MPInt> mat(5, 5);
   for (unsigned row = 0; row < 5; ++row)
     for (unsigned col = 0; col < 5; ++col)
       mat(row, col) = row == 2 ? 1 : 0;
@@ -69,7 +70,7 @@ TEST(MatrixTest, SwapRows) {
 }
 
 TEST(MatrixTest, resizeVertically) {
-  Matrix mat(5, 5);
+  Matrix<MPInt> mat(5, 5);
   EXPECT_EQ(mat.getNumRows(), 5u);
   EXPECT_EQ(mat.getNumColumns(), 5u);
   for (unsigned row = 0; row < 5; ++row)
@@ -94,7 +95,7 @@ TEST(MatrixTest, resizeVertically) {
 }
 
 TEST(MatrixTest, insertColumns) {
-  Matrix mat(5, 5, 5, 10);
+  Matrix<MPInt> mat(5, 5, 5, 10);
   EXPECT_EQ(mat.getNumRows(), 5u);
   EXPECT_EQ(mat.getNumColumns(), 5u);
   for (unsigned row = 0; row < 5; ++row)
@@ -131,7 +132,7 @@ TEST(MatrixTest, insertColumns) {
 }
 
 TEST(MatrixTest, insertRows) {
-  Matrix mat(5, 5, 5, 10);
+  Matrix<MPInt> mat(5, 5, 5, 10);
   ASSERT_TRUE(mat.hasConsistentState());
   EXPECT_EQ(mat.getNumRows(), 5u);
   EXPECT_EQ(mat.getNumColumns(), 5u);
@@ -169,7 +170,7 @@ TEST(MatrixTest, insertRows) {
 }
 
 TEST(MatrixTest, resize) {
-  Matrix mat(5, 5);
+  Matrix<MPInt> mat(5, 5);
   EXPECT_EQ(mat.getNumRows(), 5u);
   EXPECT_EQ(mat.getNumColumns(), 5u);
   for (unsigned row = 0; row < 5; ++row)
@@ -193,8 +194,8 @@ TEST(MatrixTest, resize) {
       EXPECT_EQ(mat(row, col), row >= 3 || col >= 3 ? 0 : int(10 * row + col));
 }
 
-static void checkHermiteNormalForm(const Matrix &mat,
-                                   const Matrix &hermiteForm) {
+static void checkHermiteNormalForm(const Matrix<MPInt> &mat,
+                                   const Matrix<MPInt> &hermiteForm) {
   auto [h, u] = mat.computeHermiteNormalForm();
 
   for (unsigned row = 0; row < mat.getNumRows(); row++)
@@ -208,42 +209,42 @@ TEST(MatrixTest, computeHermiteNormalForm) {
 
   {
     // Hermite form of a unimodular matrix is the identity matrix.
-    Matrix mat = makeMatrix(3, 3, {{2, 3, 6}, {3, 2, 3}, {17, 11, 16}});
-    Matrix hermiteForm = makeMatrix(3, 3, {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}});
+    Matrix<MPInt> mat = makeIntMatrix(3, 3, {{2, 3, 6}, {3, 2, 3}, {17, 11, 16}});
+    Matrix<MPInt> hermiteForm = makeIntMatrix(3, 3, {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}});
     checkHermiteNormalForm(mat, hermiteForm);
   }
 
   {
     // Hermite form of a unimodular is the identity matrix.
-    Matrix mat = makeMatrix(
+    Matrix<MPInt> mat = makeIntMatrix(
         4, 4,
         {{-6, -1, -19, -20}, {0, 1, 0, 0}, {-5, 0, -15, -16}, {6, 0, 18, 19}});
-    Matrix hermiteForm = makeMatrix(
+    Matrix<MPInt> hermiteForm = makeIntMatrix(
         4, 4, {{1, 0, 0, 0}, {0, 1, 0, 0}, {0, 0, 1, 0}, {0, 0, 0, 1}});
     checkHermiteNormalForm(mat, hermiteForm);
   }
 
   {
-    Matrix mat = makeMatrix(
+    Matrix<MPInt> mat = makeIntMatrix(
         4, 4, {{3, 3, 1, 4}, {0, 1, 0, 0}, {0, 0, 19, 16}, {0, 0, 0, 3}});
-    Matrix hermiteForm = makeMatrix(
+    Matrix<MPInt> hermiteForm = makeIntMatrix(
         4, 4, {{1, 0, 0, 0}, {0, 1, 0, 0}, {1, 0, 3, 0}, {18, 0, 54, 57}});
     checkHermiteNormalForm(mat, hermiteForm);
   }
 
   {
-    Matrix mat = makeMatrix(
+    Matrix<MPInt> mat = makeIntMatrix(
         4, 4, {{3, 3, 1, 4}, {0, 1, 0, 0}, {0, 0, 19, 16}, {0, 0, 0, 3}});
-    Matrix hermiteForm = makeMatrix(
+    Matrix<MPInt> hermiteForm = makeIntMatrix(
         4, 4, {{1, 0, 0, 0}, {0, 1, 0, 0}, {1, 0, 3, 0}, {18, 0, 54, 57}});
     checkHermiteNormalForm(mat, hermiteForm);
   }
 
   {
-    Matrix mat =
-        makeMatrix(3, 5, {{0, 2, 0, 7, 1}, {-1, 0, 0, -3, 0}, {0, 4, 1, 0, 8}});
-    Matrix hermiteForm =
-        makeMatrix(3, 5, {{1, 0, 0, 0, 0}, {0, 1, 0, 0, 0}, {0, 0, 1, 0, 0}});
+    Matrix<MPInt> mat =
+        makeIntMatrix(3, 5, {{0, 2, 0, 7, 1}, {-1, 0, 0, -3, 0}, {0, 4, 1, 0, 8}});
+    Matrix<MPInt> hermiteForm =
+        makeIntMatrix(3, 5, {{1, 0, 0, 0, 0}, {0, 1, 0, 0, 0}, {0, 0, 1, 0, 0}});
     checkHermiteNormalForm(mat, hermiteForm);
   }
 }
diff --git a/mlir/unittests/Analysis/Presburger/Parser.h b/mlir/unittests/Analysis/Presburger/Parser.h
index c2c63730056e7fe..bd9b6f07664c7e7 100644
--- a/mlir/unittests/Analysis/Presburger/Parser.h
+++ b/mlir/unittests/Analysis/Presburger/Parser.h
@@ -52,7 +52,7 @@ inline MultiAffineFunction parseMultiAffineFunction(StringRef str) {
 
   // TODO: Add default constructor for MultiAffineFunction.
   MultiAffineFunction multiAff(PresburgerSpace::getRelationSpace(),
-                               Matrix(0, 1));
+                               Matrix<MPInt>(0, 1));
   if (getMultiAffineFunctionFromMap(parseAffineMap(str, &context), multiAff)
           .failed())
     llvm_unreachable(
diff --git a/mlir/unittests/Analysis/Presburger/Utils.h b/mlir/unittests/Analysis/Presburger/Utils.h
index c3246a09d5ae9be..8a7f86c866b7056 100644
--- a/mlir/unittests/Analysis/Presburger/Utils.h
+++ b/mlir/unittests/Analysis/Presburger/Utils.h
@@ -17,6 +17,7 @@
 #include "mlir/Analysis/Presburger/PWMAFunction.h"
 #include "mlir/Analysis/Presburger/PresburgerRelation.h"
 #include "mlir/Analysis/Presburger/Simplex.h"
+#include "mlir/Analysis/Presburger/Matrix.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/Support/LLVM.h"
 
@@ -26,9 +27,22 @@
 namespace mlir {
 namespace presburger {
 
-inline Matrix makeMatrix(unsigned numRow, unsigned numColumns,
-                         ArrayRef<SmallVector<int64_t, 8>> matrix) {
-  Matrix results(numRow, numColumns);
+inline Matrix<MPInt> makeIntMatrix(unsigned numRow, unsigned numColumns,
+                         ArrayRef<SmallVector<int, 8>> matrix) {
+  Matrix<MPInt> results(numRow, numColumns);
+  assert(matrix.size() == numRow);
+  for (unsigned i = 0; i < numRow; ++i) {
+    assert(matrix[i].size() == numColumns &&
+           "Output expression has incorrect dimensionality!");
+    for (unsigned j = 0; j < numColumns; ++j)
+      results(i, j) = MPInt(matrix[i][j]);
+  }
+  return results;
+}
+
+inline Matrix<Fraction> makeFracMatrix(unsigned numRow, unsigned numColumns,
+                         ArrayRef<SmallVector<Fraction, 8>> matrix) {
+  Matrix<Fraction> results(numRow, numColumns);
   assert(matrix.size() == numRow);
   for (unsigned i = 0; i < numRow; ++i) {
     assert(matrix[i].size() == numColumns &&

>From 04d41997d3055c9977f7c9c3a8d610243838436b Mon Sep 17 00:00:00 2001
From: Abhinav271828 <abhinav.m at research.iiit.ac.in>
Date: Thu, 31 Aug 2023 17:30:17 +0100
Subject: [PATCH 2/4] Template Matrix to Matrix<T> (for MPInt and Fraction)
 with explicit instantiation. Duplicate makeMatrix to makeIntMatrix and
 makeFracMatrix.

Implement arithmetic operations for Fraction for compatibility
---
 .../mlir/Analysis/Presburger/Fraction.h       |  28 ++++-
 .../Analysis/Presburger/IntegerRelation.h     |   6 +-
 .../Analysis/Presburger/LinearTransform.h     |  12 +-
 .../include/mlir/Analysis/Presburger/Matrix.h |  42 +++----
 .../mlir/Analysis/Presburger/PWMAFunction.h   |   8 +-
 .../mlir/Analysis/Presburger/Simplex.h        |   4 +-
 mlir/include/mlir/Analysis/Presburger/Utils.h |   2 +-
 .../Analysis/FlatLinearValueConstraints.cpp   |   2 +-
 .../Analysis/Presburger/IntegerRelation.cpp   |   6 +-
 .../Analysis/Presburger/LinearTransform.cpp   |   6 +-
 mlir/lib/Analysis/Presburger/Matrix.cpp       | 114 ++++++++++--------
 mlir/lib/Analysis/Presburger/Simplex.cpp      |   4 +-
 .../Presburger/LinearTransformTest.cpp        |  14 +--
 .../Analysis/Presburger/MatrixTest.cpp        |  43 +++----
 mlir/unittests/Analysis/Presburger/Parser.h   |   2 +-
 mlir/unittests/Analysis/Presburger/Utils.h    |  20 ++-
 16 files changed, 182 insertions(+), 131 deletions(-)

diff --git a/mlir/include/mlir/Analysis/Presburger/Fraction.h b/mlir/include/mlir/Analysis/Presburger/Fraction.h
index c51b6c972bf8851..2cb90b708435353 100644
--- a/mlir/include/mlir/Analysis/Presburger/Fraction.h
+++ b/mlir/include/mlir/Analysis/Presburger/Fraction.h
@@ -15,6 +15,7 @@
 #define MLIR_ANALYSIS_PRESBURGER_FRACTION_H
 
 #include "mlir/Analysis/Presburger/MPInt.h"
+#include "mlir/Analysis/Presburger/Utils.h"
 #include "mlir/Support/MathExtras.h"
 
 namespace mlir {
@@ -30,15 +31,15 @@ struct Fraction {
   Fraction() = default;
 
   /// Construct a Fraction from a numerator and denominator.
-  Fraction(const MPInt &oNum, const MPInt &oDen) : num(oNum), den(oDen) {
+  Fraction(const MPInt &oNum, const MPInt &oDen = MPInt(1)) : num(oNum), den(oDen) {
     if (den < 0) {
       num = -num;
       den = -den;
     }
   }
   /// Overloads for passing literals.
-  Fraction(const MPInt &num, int64_t den) : Fraction(num, MPInt(den)) {}
-  Fraction(int64_t num, const MPInt &den) : Fraction(MPInt(num), den) {}
+  Fraction(const MPInt &num, int64_t den = 1) : Fraction(num, MPInt(den)) {}
+  Fraction(int64_t num, const MPInt &den = MPInt(1)) : Fraction(MPInt(num), den) {}
   Fraction(int64_t num, int64_t den) : Fraction(MPInt(num), MPInt(den)) {}
 
   // Return the value of the fraction as an integer. This should only be called
@@ -48,6 +49,10 @@ struct Fraction {
     return num / den;
   }
 
+  llvm::raw_ostream &print(llvm::raw_ostream &os) const {
+    return os << "(" << num << "/" << den << ")";
+  }
+
   /// The numerator and denominator, respectively. The denominator is always
   /// positive.
   MPInt num{0}, den{1};
@@ -99,6 +104,23 @@ inline Fraction operator*(const Fraction &x, const Fraction &y) {
   return Fraction(x.num * y.num, x.den * y.den);
 }
 
+inline Fraction operator/(const Fraction &x, const Fraction &y) {
+  return Fraction(x.num * y.den, x.den * y.num);
+}
+
+inline Fraction operator+(const Fraction &x, const Fraction &y) {
+  return Fraction(x.num * y.den + x.den * y.num, x.den * y.den);
+}
+
+inline Fraction operator-(const Fraction &x, const Fraction &y) {
+  return Fraction(x.num * y.den - x.den * y.num, x.den * y.den);
+}
+
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Fraction &x) {
+  x.print(os);
+  return os;
+}
+
 } // namespace presburger
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index fb1401cfbcf20f8..9a94fbe5a532bff 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -366,7 +366,7 @@ class IntegerRelation {
   /// bounded. The span of the returned vectors is guaranteed to contain all
   /// such vectors. The returned vectors are NOT guaranteed to be linearly
   /// independent. This function should not be called on empty sets.
-  Matrix getBoundedDirections() const;
+  Matrix<MPInt> getBoundedDirections() const;
 
   /// Find an integer sample point satisfying the constraints using a
   /// branch and bound algorithm with generalized basis reduction, with some
@@ -792,10 +792,10 @@ class IntegerRelation {
   PresburgerSpace space;
 
   /// Coefficients of affine equalities (in == 0 form).
-  Matrix equalities;
+  Matrix<MPInt> equalities;
 
   /// Coefficients of affine inequalities (in >= 0 form).
-  Matrix inequalities;
+  Matrix<MPInt> inequalities;
 };
 
 /// An IntegerPolyhedron represents the set of points from a PresburgerSpace
diff --git a/mlir/include/mlir/Analysis/Presburger/LinearTransform.h b/mlir/include/mlir/Analysis/Presburger/LinearTransform.h
index cd56951fe773f8b..686e846a16c78a7 100644
--- a/mlir/include/mlir/Analysis/Presburger/LinearTransform.h
+++ b/mlir/include/mlir/Analysis/Presburger/LinearTransform.h
@@ -22,8 +22,8 @@ namespace presburger {
 
 class LinearTransform {
 public:
-  explicit LinearTransform(Matrix &&oMatrix);
-  explicit LinearTransform(const Matrix &oMatrix);
+  explicit LinearTransform(Matrix<MPInt> &&oMatrix);
+  explicit LinearTransform(const Matrix<MPInt> &oMatrix);
 
   // Returns a linear transform T such that MT is M in column echelon form.
   // Also returns the number of non-zero columns in MT.
@@ -32,7 +32,7 @@ class LinearTransform {
   // strictly below that of the previous column, and all columns which have only
   // zeros are at the end.
   static std::pair<unsigned, LinearTransform>
-  makeTransformToColumnEchelon(const Matrix &m);
+  makeTransformToColumnEchelon(const Matrix<MPInt> &m);
 
   // Returns an IntegerRelation having a constraint vector vT for every
   // constraint vector v in rel, where T is this transform.
@@ -50,8 +50,12 @@ class LinearTransform {
     return matrix.postMultiplyWithColumn(colVec);
   }
 
+  // Compute the determinant of the transform by converting it to row echelon
+  // form and then taking the product of the diagonal.
+  MPInt determinant();
+
 private:
-  Matrix matrix;
+  Matrix<MPInt> matrix;
 };
 
 } // namespace presburger
diff --git a/mlir/include/mlir/Analysis/Presburger/Matrix.h b/mlir/include/mlir/Analysis/Presburger/Matrix.h
index bae1661d9ce6c60..b03737ab2f70a4a 100644
--- a/mlir/include/mlir/Analysis/Presburger/Matrix.h
+++ b/mlir/include/mlir/Analysis/Presburger/Matrix.h
@@ -14,7 +14,6 @@
 #ifndef MLIR_ANALYSIS_PRESBURGER_MATRIX_H
 #define MLIR_ANALYSIS_PRESBURGER_MATRIX_H
 
-#include "mlir/Analysis/Presburger/MPInt.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/raw_ostream.h"
@@ -32,6 +31,7 @@ namespace presburger {
 /// (i, j) is stored at data[i*nReservedColumns + j]. The reserved but unused
 /// columns always have all zero values. The reserved rows are just reserved
 /// space in the underlying SmallVector's capacity.
+template<typename T>
 class Matrix {
 public:
   Matrix() = delete;
@@ -49,21 +49,21 @@ class Matrix {
   static Matrix identity(unsigned dimension);
 
   /// Access the element at the specified row and column.
-  MPInt &at(unsigned row, unsigned column) {
+  T &at(unsigned row, unsigned column) {
     assert(row < nRows && "Row outside of range");
     assert(column < nColumns && "Column outside of range");
     return data[row * nReservedColumns + column];
   }
 
-  MPInt at(unsigned row, unsigned column) const {
+  T at(unsigned row, unsigned column) const {
     assert(row < nRows && "Row outside of range");
     assert(column < nColumns && "Column outside of range");
     return data[row * nReservedColumns + column];
   }
 
-  MPInt &operator()(unsigned row, unsigned column) { return at(row, column); }
+  T &operator()(unsigned row, unsigned column) { return at(row, column); }
 
-  MPInt operator()(unsigned row, unsigned column) const {
+  T operator()(unsigned row, unsigned column) const {
     return at(row, column);
   }
 
@@ -87,11 +87,11 @@ class Matrix {
   void reserveRows(unsigned rows);
 
   /// Get a [Mutable]ArrayRef corresponding to the specified row.
-  MutableArrayRef<MPInt> getRow(unsigned row);
-  ArrayRef<MPInt> getRow(unsigned row) const;
+  MutableArrayRef<T> getRow(unsigned row);
+  ArrayRef<T> getRow(unsigned row) const;
 
   /// Set the specified row to `elems`.
-  void setRow(unsigned row, ArrayRef<MPInt> elems);
+  void setRow(unsigned row, ArrayRef<T> elems);
 
   /// Insert columns having positions pos, pos + 1, ... pos + count - 1.
   /// Columns that were at positions 0 to pos - 1 will stay where they are;
@@ -125,23 +125,23 @@ class Matrix {
 
   void copyRow(unsigned sourceRow, unsigned targetRow);
 
-  void fillRow(unsigned row, const MPInt &value);
-  void fillRow(unsigned row, int64_t value) { fillRow(row, MPInt(value)); }
+  void fillRow(unsigned row, const T &value);
+  void fillRow(unsigned row, int64_t value) { fillRow(row, T(value)); }
 
   /// Add `scale` multiples of the source row to the target row.
-  void addToRow(unsigned sourceRow, unsigned targetRow, const MPInt &scale);
+  void addToRow(unsigned sourceRow, unsigned targetRow, const T &scale);
   void addToRow(unsigned sourceRow, unsigned targetRow, int64_t scale) {
-    addToRow(sourceRow, targetRow, MPInt(scale));
+    addToRow(sourceRow, targetRow, T(scale));
   }
   /// Add `scale` multiples of the rowVec row to the specified row.
-  void addToRow(unsigned row, ArrayRef<MPInt> rowVec, const MPInt &scale);
+  void addToRow(unsigned row, ArrayRef<T> rowVec, const T &scale);
 
   /// Add `scale` multiples of the source column to the target column.
   void addToColumn(unsigned sourceColumn, unsigned targetColumn,
-                   const MPInt &scale);
+                   const T &scale);
   void addToColumn(unsigned sourceColumn, unsigned targetColumn,
                    int64_t scale) {
-    addToColumn(sourceColumn, targetColumn, MPInt(scale));
+    addToColumn(sourceColumn, targetColumn, T(scale));
   }
 
   /// Negate the specified column.
@@ -152,18 +152,18 @@ class Matrix {
 
   /// Divide the first `nCols` of the specified row by their GCD.
   /// Returns the GCD of the first `nCols` of the specified row.
-  MPInt normalizeRow(unsigned row, unsigned nCols);
+  T normalizeRow(unsigned row, unsigned nCols);
   /// Divide the columns of the specified row by their GCD.
   /// Returns the GCD of the columns of the specified row.
-  MPInt normalizeRow(unsigned row);
+  T normalizeRow(unsigned row);
 
   /// The given vector is interpreted as a row vector v. Post-multiply v with
   /// this matrix, say M, and return vM.
-  SmallVector<MPInt, 8> preMultiplyWithRow(ArrayRef<MPInt> rowVec) const;
+  SmallVector<T, 8> preMultiplyWithRow(ArrayRef<T> rowVec) const;
 
   /// The given vector is interpreted as a column vector v. Pre-multiply v with
   /// this matrix, say M, and return Mv.
-  SmallVector<MPInt, 8> postMultiplyWithColumn(ArrayRef<MPInt> colVec) const;
+  SmallVector<T, 8> postMultiplyWithColumn(ArrayRef<T> colVec) const;
 
   /// Given the current matrix M, returns the matrices H, U such that H is the
   /// column hermite normal form of M, i.e. H = M * U, where U is unimodular and
@@ -192,7 +192,7 @@ class Matrix {
   unsigned appendExtraRow();
   /// Same as above, but copy the given elements into the row. The length of
   /// `elems` must be equal to the number of columns.
-  unsigned appendExtraRow(ArrayRef<MPInt> elems);
+  unsigned appendExtraRow(ArrayRef<T> elems);
 
   /// Print the matrix.
   void print(raw_ostream &os) const;
@@ -211,7 +211,7 @@ class Matrix {
 
   /// Stores the data. data.size() is equal to nRows * nReservedColumns.
   /// data.capacity() / nReservedColumns is the number of reserved rows.
-  SmallVector<MPInt, 16> data;
+  SmallVector<T, 16> data;
 };
 
 } // namespace presburger
diff --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
index ea3456624e72d4e..0b3804fc08a60e0 100644
--- a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
+++ b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
@@ -40,13 +40,13 @@ enum class OrderingKind { EQ, NE, LT, LE, GT, GE };
 /// value of the function at a specified point.
 class MultiAffineFunction {
 public:
-  MultiAffineFunction(const PresburgerSpace &space, const Matrix &output)
+  MultiAffineFunction(const PresburgerSpace &space, const Matrix<MPInt> &output)
       : space(space), output(output),
         divs(space.getNumVars() - space.getNumRangeVars()) {
     assertIsConsistent();
   }
 
-  MultiAffineFunction(const PresburgerSpace &space, const Matrix &output,
+  MultiAffineFunction(const PresburgerSpace &space, const Matrix<MPInt> &output,
                       const DivisionRepr &divs)
       : space(space), output(output), divs(divs) {
     assertIsConsistent();
@@ -65,7 +65,7 @@ class MultiAffineFunction {
   PresburgerSpace getOutputSpace() const { return space.getRangeSpace(); }
 
   /// Get a matrix with each row representing row^th output expression.
-  const Matrix &getOutputMatrix() const { return output; }
+  const Matrix<MPInt> &getOutputMatrix() const { return output; }
   /// Get the `i^th` output expression.
   ArrayRef<MPInt> getOutputExpr(unsigned i) const { return output.getRow(i); }
 
@@ -124,7 +124,7 @@ class MultiAffineFunction {
   /// The function's output is a tuple of integers, with the ith element of the
   /// tuple defined by the affine expression given by the ith row of this output
   /// matrix.
-  Matrix output;
+  Matrix<MPInt> output;
 
   /// Storage for division representation for each local variable in space.
   DivisionRepr divs;
diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h
index 79a42d6c38d4113..922b0cb33168fb6 100644
--- a/mlir/include/mlir/Analysis/Presburger/Simplex.h
+++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h
@@ -338,7 +338,7 @@ class SimplexBase {
   unsigned nSymbol;
 
   /// The matrix representing the tableau.
-  Matrix tableau;
+  Matrix<MPInt> tableau;
 
   /// This is true if the tableau has been detected to be empty, false
   /// otherwise.
@@ -861,7 +861,7 @@ class Simplex : public SimplexBase {
 
   /// Reduce the given basis, starting at the specified level, using general
   /// basis reduction.
-  void reduceBasis(Matrix &basis, unsigned level);
+  void reduceBasis(Matrix<MPInt> &basis, unsigned level);
 };
 
 /// Takes a snapshot of the simplex state on construction and rolls back to the
diff --git a/mlir/include/mlir/Analysis/Presburger/Utils.h b/mlir/include/mlir/Analysis/Presburger/Utils.h
index a3000a26c3f3d76..d3822ed572f8ee8 100644
--- a/mlir/include/mlir/Analysis/Presburger/Utils.h
+++ b/mlir/include/mlir/Analysis/Presburger/Utils.h
@@ -182,7 +182,7 @@ class DivisionRepr {
   /// Each row of the Matrix represents a single division dividend. The
   /// `i^th` row represents the dividend of the variable at `divOffset + i`
   /// in the constraint system (and the `i^th` local variable).
-  Matrix dividends;
+  Matrix<MPInt> dividends;
 
   /// Denominators of each division. If a denominator of a division is `0`, the
   /// division variable is considered to not have a division representation.
diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
index 3f000182250069d..31aff1a216bacc3 100644
--- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
+++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
@@ -1292,7 +1292,7 @@ mlir::getMultiAffineFunctionFromMap(AffineMap map,
          "AffineMap cannot produce divs without local representation");
 
   // TODO: We shouldn't have to do this conversion.
-  Matrix mat(map.getNumResults(), map.getNumInputs() + divs.getNumDivs() + 1);
+  Matrix<MPInt> mat(map.getNumResults(), map.getNumInputs() + divs.getNumDivs() + 1);
   for (unsigned i = 0, e = flattenedExprs.size(); i < e; ++i)
     for (unsigned j = 0, f = flattenedExprs[i].size(); j < f; ++j)
       mat(i, j) = flattenedExprs[i][j];
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index 6f07c364d07653c..4672de03b40693d 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -738,7 +738,7 @@ bool IntegerRelation::isEmptyByGCDTest() const {
 //
 // It is sufficient to check the perpendiculars of the constraints, as the set
 // of perpendiculars which are bounded must span all bounded directions.
-Matrix IntegerRelation::getBoundedDirections() const {
+Matrix<MPInt> IntegerRelation::getBoundedDirections() const {
   // Note that it is necessary to add the equalities too (which the constructor
   // does) even though we don't need to check if they are bounded; whether an
   // inequality is bounded or not depends on what other constraints, including
@@ -759,7 +759,7 @@ Matrix IntegerRelation::getBoundedDirections() const {
   // The direction vector is given by the coefficients and does not include the
   // constant term, so the matrix has one fewer column.
   unsigned dirsNumCols = getNumCols() - 1;
-  Matrix dirs(boundedIneqs.size() + getNumEqualities(), dirsNumCols);
+  Matrix<MPInt> dirs(boundedIneqs.size() + getNumEqualities(), dirsNumCols);
 
   // Copy the bounded inequalities.
   unsigned row = 0;
@@ -845,7 +845,7 @@ IntegerRelation::findIntegerSample() const {
   // m is a matrix containing, in each row, a vector in which S is
   // bounded, such that the linear span of all these dimensions contains all
   // bounded dimensions in S.
-  Matrix m = getBoundedDirections();
+  Matrix<MPInt> m = getBoundedDirections();
   // In column echelon form, each row of m occupies only the first rank(m)
   // columns and has zeros on the other columns. The transform T that brings S
   // to column echelon form is unimodular as well, so this is a suitable
diff --git a/mlir/lib/Analysis/Presburger/LinearTransform.cpp b/mlir/lib/Analysis/Presburger/LinearTransform.cpp
index e7ad3ecf4306d38..d25e76d9229f605 100644
--- a/mlir/lib/Analysis/Presburger/LinearTransform.cpp
+++ b/mlir/lib/Analysis/Presburger/LinearTransform.cpp
@@ -12,11 +12,11 @@
 using namespace mlir;
 using namespace presburger;
 
-LinearTransform::LinearTransform(Matrix &&oMatrix) : matrix(oMatrix) {}
-LinearTransform::LinearTransform(const Matrix &oMatrix) : matrix(oMatrix) {}
+LinearTransform::LinearTransform(Matrix<MPInt> &&oMatrix) : matrix(oMatrix) {}
+LinearTransform::LinearTransform(const Matrix<MPInt> &oMatrix) : matrix(oMatrix) {}
 
 std::pair<unsigned, LinearTransform>
-LinearTransform::makeTransformToColumnEchelon(const Matrix &m) {
+LinearTransform::makeTransformToColumnEchelon(const Matrix<MPInt> &m) {
   // Compute the hermite normal form of m. This, is by definition, is in column
   // echelon form.
   auto [h, u] = m.computeHermiteNormalForm();
diff --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp
index 4ee81c61a53a3b5..c19e5d8d49fec37 100644
--- a/mlir/lib/Analysis/Presburger/Matrix.cpp
+++ b/mlir/lib/Analysis/Presburger/Matrix.cpp
@@ -7,13 +7,14 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Analysis/Presburger/Matrix.h"
+#include "mlir/Analysis/Presburger/Fraction.h"
 #include "mlir/Analysis/Presburger/Utils.h"
 #include "llvm/Support/MathExtras.h"
 
 using namespace mlir;
 using namespace presburger;
 
-Matrix::Matrix(unsigned rows, unsigned columns, unsigned reservedRows,
+template <typename T> Matrix<T>::Matrix(unsigned rows, unsigned columns, unsigned reservedRows,
                unsigned reservedColumns)
     : nRows(rows), nColumns(columns),
       nReservedColumns(std::max(nColumns, reservedColumns)),
@@ -21,27 +22,27 @@ Matrix::Matrix(unsigned rows, unsigned columns, unsigned reservedRows,
   data.reserve(std::max(nRows, reservedRows) * nReservedColumns);
 }
 
-Matrix Matrix::identity(unsigned dimension) {
+template <typename T> Matrix<T> Matrix<T>::identity(unsigned dimension) {
   Matrix matrix(dimension, dimension);
   for (unsigned i = 0; i < dimension; ++i)
     matrix(i, i) = 1;
   return matrix;
 }
 
-unsigned Matrix::getNumReservedRows() const {
+template <typename T> unsigned Matrix<T>::getNumReservedRows() const {
   return data.capacity() / nReservedColumns;
 }
 
-void Matrix::reserveRows(unsigned rows) {
+template <typename T> void Matrix<T>::reserveRows(unsigned rows) {
   data.reserve(rows * nReservedColumns);
 }
 
-unsigned Matrix::appendExtraRow() {
+template <typename T> unsigned Matrix<T>::appendExtraRow() {
   resizeVertically(nRows + 1);
   return nRows - 1;
 }
 
-unsigned Matrix::appendExtraRow(ArrayRef<MPInt> elems) {
+template <typename T> unsigned Matrix<T>::appendExtraRow(ArrayRef<T> elems) {
   assert(elems.size() == nColumns && "elems must match row length!");
   unsigned row = appendExtraRow();
   for (unsigned col = 0; col < nColumns; ++col)
@@ -49,24 +50,24 @@ unsigned Matrix::appendExtraRow(ArrayRef<MPInt> elems) {
   return row;
 }
 
-void Matrix::resizeHorizontally(unsigned newNColumns) {
+template <typename T> void Matrix<T>::resizeHorizontally(unsigned newNColumns) {
   if (newNColumns < nColumns)
     removeColumns(newNColumns, nColumns - newNColumns);
   if (newNColumns > nColumns)
     insertColumns(nColumns, newNColumns - nColumns);
 }
 
-void Matrix::resize(unsigned newNRows, unsigned newNColumns) {
+template <typename T> void Matrix<T>::resize(unsigned newNRows, unsigned newNColumns) {
   resizeHorizontally(newNColumns);
   resizeVertically(newNRows);
 }
 
-void Matrix::resizeVertically(unsigned newNRows) {
+template <typename T> void Matrix<T>::resizeVertically(unsigned newNRows) {
   nRows = newNRows;
   data.resize(nRows * nReservedColumns);
 }
 
-void Matrix::swapRows(unsigned row, unsigned otherRow) {
+template <typename T> void Matrix<T>::swapRows(unsigned row, unsigned otherRow) {
   assert((row < getNumRows() && otherRow < getNumRows()) &&
          "Given row out of bounds");
   if (row == otherRow)
@@ -75,7 +76,7 @@ void Matrix::swapRows(unsigned row, unsigned otherRow) {
     std::swap(at(row, col), at(otherRow, col));
 }
 
-void Matrix::swapColumns(unsigned column, unsigned otherColumn) {
+template <typename T> void Matrix<T>::swapColumns(unsigned column, unsigned otherColumn) {
   assert((column < getNumColumns() && otherColumn < getNumColumns()) &&
          "Given column out of bounds");
   if (column == otherColumn)
@@ -84,23 +85,23 @@ void Matrix::swapColumns(unsigned column, unsigned otherColumn) {
     std::swap(at(row, column), at(row, otherColumn));
 }
 
-MutableArrayRef<MPInt> Matrix::getRow(unsigned row) {
+template <typename T> MutableArrayRef<T> Matrix<T>::getRow(unsigned row) {
   return {&data[row * nReservedColumns], nColumns};
 }
 
-ArrayRef<MPInt> Matrix::getRow(unsigned row) const {
+template <typename T> ArrayRef<T> Matrix<T>::getRow(unsigned row) const {
   return {&data[row * nReservedColumns], nColumns};
 }
 
-void Matrix::setRow(unsigned row, ArrayRef<MPInt> elems) {
+template <typename T> void Matrix<T>::setRow(unsigned row, ArrayRef<T> elems) {
   assert(elems.size() == getNumColumns() &&
          "elems size must match row length!");
   for (unsigned i = 0, e = getNumColumns(); i < e; ++i)
     at(row, i) = elems[i];
 }
 
-void Matrix::insertColumn(unsigned pos) { insertColumns(pos, 1); }
-void Matrix::insertColumns(unsigned pos, unsigned count) {
+template <typename T> void Matrix<T>::insertColumn(unsigned pos) { insertColumns(pos, 1); }
+template <typename T> void Matrix<T>::insertColumns(unsigned pos, unsigned count) {
   if (count == 0)
     return;
   assert(pos <= nColumns);
@@ -115,7 +116,7 @@ void Matrix::insertColumns(unsigned pos, unsigned count) {
     for (int ci = nReservedColumns - 1; ci >= 0; --ci) {
       unsigned r = ri;
       unsigned c = ci;
-      MPInt &dest = data[r * nReservedColumns + c];
+      T &dest = data[r * nReservedColumns + c];
       if (c >= nColumns) { // NOLINT
         // Out of bounds columns are zero-initialized. NOLINT because clang-tidy
         // complains about this branch being the same as the c >= pos one.
@@ -141,8 +142,8 @@ void Matrix::insertColumns(unsigned pos, unsigned count) {
   }
 }
 
-void Matrix::removeColumn(unsigned pos) { removeColumns(pos, 1); }
-void Matrix::removeColumns(unsigned pos, unsigned count) {
+template <typename T> void Matrix<T>::removeColumn(unsigned pos) { removeColumns(pos, 1); }
+template <typename T> void Matrix<T>::removeColumns(unsigned pos, unsigned count) {
   if (count == 0)
     return;
   assert(pos + count - 1 < nColumns);
@@ -155,8 +156,8 @@ void Matrix::removeColumns(unsigned pos, unsigned count) {
   nColumns -= count;
 }
 
-void Matrix::insertRow(unsigned pos) { insertRows(pos, 1); }
-void Matrix::insertRows(unsigned pos, unsigned count) {
+template <typename T> void Matrix<T>::insertRow(unsigned pos) { insertRows(pos, 1); }
+template <typename T> void Matrix<T>::insertRows(unsigned pos, unsigned count) {
   if (count == 0)
     return;
 
@@ -169,8 +170,8 @@ void Matrix::insertRows(unsigned pos, unsigned count) {
       at(r, c) = 0;
 }
 
-void Matrix::removeRow(unsigned pos) { removeRows(pos, 1); }
-void Matrix::removeRows(unsigned pos, unsigned count) {
+template <typename T> void Matrix<T>::removeRow(unsigned pos) { removeRows(pos, 1); }
+template <typename T> void Matrix<T>::removeRows(unsigned pos, unsigned count) {
   if (count == 0)
     return;
   assert(pos + count - 1 <= nRows);
@@ -179,76 +180,76 @@ void Matrix::removeRows(unsigned pos, unsigned count) {
   resizeVertically(nRows - count);
 }
 
-void Matrix::copyRow(unsigned sourceRow, unsigned targetRow) {
+template <typename T> void Matrix<T>::copyRow(unsigned sourceRow, unsigned targetRow) {
   if (sourceRow == targetRow)
     return;
   for (unsigned c = 0; c < nColumns; ++c)
     at(targetRow, c) = at(sourceRow, c);
 }
 
-void Matrix::fillRow(unsigned row, const MPInt &value) {
+template <typename T> void Matrix<T>::fillRow(unsigned row, const T &value) {
   for (unsigned col = 0; col < nColumns; ++col)
     at(row, col) = value;
 }
 
-void Matrix::addToRow(unsigned sourceRow, unsigned targetRow,
-                      const MPInt &scale) {
+template <typename T> void Matrix<T>::addToRow(unsigned sourceRow, unsigned targetRow,
+                      const T &scale) {
   addToRow(targetRow, getRow(sourceRow), scale);
 }
 
-void Matrix::addToRow(unsigned row, ArrayRef<MPInt> rowVec,
-                      const MPInt &scale) {
+template <typename T> void Matrix<T>::addToRow(unsigned row, ArrayRef<T> rowVec,
+                      const T &scale) {
   if (scale == 0)
     return;
   for (unsigned col = 0; col < nColumns; ++col)
-    at(row, col) += scale * rowVec[col];
+    at(row, col) = at(row, col) + scale * rowVec[col];
 }
 
-void Matrix::addToColumn(unsigned sourceColumn, unsigned targetColumn,
-                         const MPInt &scale) {
+template <typename T> void Matrix<T>::addToColumn(unsigned sourceColumn, unsigned targetColumn,
+                         const T &scale) {
   if (scale == 0)
     return;
   for (unsigned row = 0, e = getNumRows(); row < e; ++row)
-    at(row, targetColumn) += scale * at(row, sourceColumn);
+    at(row, targetColumn) = at(row, targetColumn) + scale * at(row, sourceColumn);
 }
 
-void Matrix::negateColumn(unsigned column) {
+template <typename T> void Matrix<T>::negateColumn(unsigned column) {
   for (unsigned row = 0, e = getNumRows(); row < e; ++row)
     at(row, column) = -at(row, column);
 }
 
-void Matrix::negateRow(unsigned row) {
+template <typename T> void Matrix<T>::negateRow(unsigned row) {
   for (unsigned column = 0, e = getNumColumns(); column < e; ++column)
     at(row, column) = -at(row, column);
 }
 
-MPInt Matrix::normalizeRow(unsigned row, unsigned cols) {
+template <> MPInt Matrix<MPInt>::normalizeRow(unsigned row, unsigned cols) {
   return normalizeRange(getRow(row).slice(0, cols));
 }
 
-MPInt Matrix::normalizeRow(unsigned row) {
+template <> MPInt Matrix<MPInt>::normalizeRow(unsigned row) {
   return normalizeRow(row, getNumColumns());
 }
 
-SmallVector<MPInt, 8> Matrix::preMultiplyWithRow(ArrayRef<MPInt> rowVec) const {
+template <typename T> SmallVector<T, 8> Matrix<T>::preMultiplyWithRow(ArrayRef<T> rowVec) const {
   assert(rowVec.size() == getNumRows() && "Invalid row vector dimension!");
 
-  SmallVector<MPInt, 8> result(getNumColumns(), MPInt(0));
+  SmallVector<T, 8> result(getNumColumns(), T(0));
   for (unsigned col = 0, e = getNumColumns(); col < e; ++col)
     for (unsigned i = 0, e = getNumRows(); i < e; ++i)
-      result[col] += rowVec[i] * at(i, col);
+      result[col] = result[col] + rowVec[i] * at(i, col);
   return result;
 }
 
-SmallVector<MPInt, 8>
-Matrix::postMultiplyWithColumn(ArrayRef<MPInt> colVec) const {
+template <typename T> SmallVector<T, 8>
+Matrix<T>::postMultiplyWithColumn(ArrayRef<T> colVec) const {
   assert(getNumColumns() == colVec.size() &&
          "Invalid column vector dimension!");
 
-  SmallVector<MPInt, 8> result(getNumRows(), MPInt(0));
+  SmallVector<T, 8> result(getNumRows(), T(0));
   for (unsigned row = 0, e = getNumRows(); row < e; row++)
     for (unsigned i = 0, e = getNumColumns(); i < e; i++)
-      result[row] += at(row, i) * colVec[i];
+      result[row] = result[row] + at(row, i) * colVec[i];
   return result;
 }
 
@@ -257,8 +258,8 @@ Matrix::postMultiplyWithColumn(ArrayRef<MPInt> colVec) const {
 /// sourceCol. This brings M(row, targetCol) to the range [0, M(row,
 /// sourceCol)). Apply the same column operation to otherMatrix, with the same
 /// integer multiple.
-static void modEntryColumnOperation(Matrix &m, unsigned row, unsigned sourceCol,
-                                    unsigned targetCol, Matrix &otherMatrix) {
+static void modEntryColumnOperation(Matrix<MPInt> &m, unsigned row, unsigned sourceCol,
+                                    unsigned targetCol, Matrix<MPInt> &otherMatrix) {
   assert(m(row, sourceCol) != 0 && "Cannot divide by zero!");
   assert(m(row, sourceCol) > 0 && "Source must be positive!");
   MPInt ratio = -floorDiv(m(row, targetCol), m(row, sourceCol));
@@ -266,12 +267,12 @@ static void modEntryColumnOperation(Matrix &m, unsigned row, unsigned sourceCol,
   otherMatrix.addToColumn(sourceCol, targetCol, ratio);
 }
 
-std::pair<Matrix, Matrix> Matrix::computeHermiteNormalForm() const {
+template <> std::pair<Matrix<MPInt>, Matrix<MPInt>> Matrix<MPInt>::computeHermiteNormalForm() const {
   // We start with u as an identity matrix and perform operations on h until h
   // is in hermite normal form. We apply the same sequence of operations on u to
   // obtain a transform that takes h to hermite normal form.
-  Matrix h = *this;
-  Matrix u = Matrix::identity(h.getNumColumns());
+  Matrix<MPInt> h = *this;
+  Matrix<MPInt> u = Matrix<MPInt>::identity(h.getNumColumns());
 
   unsigned echelonCol = 0;
   // Invariant: in all rows above row, all columns from echelonCol onwards
@@ -352,7 +353,7 @@ std::pair<Matrix, Matrix> Matrix::computeHermiteNormalForm() const {
   return {h, u};
 }
 
-void Matrix::print(raw_ostream &os) const {
+template <typename T> void Matrix<T>::print(raw_ostream &os) const {
   for (unsigned row = 0; row < nRows; ++row) {
     for (unsigned column = 0; column < nColumns; ++column)
       os << at(row, column) << ' ';
@@ -360,9 +361,9 @@ void Matrix::print(raw_ostream &os) const {
   }
 }
 
-void Matrix::dump() const { print(llvm::errs()); }
+template <typename T> void Matrix<T>::dump() const { print(llvm::errs()); }
 
-bool Matrix::hasConsistentState() const {
+template <typename T> bool Matrix<T>::hasConsistentState() const {
   if (data.size() != nRows * nReservedColumns)
     return false;
   if (nColumns > nReservedColumns)
@@ -375,3 +376,12 @@ bool Matrix::hasConsistentState() const {
 #endif
   return true;
 }
+
+namespace mlir
+{
+  namespace presburger
+  {
+    template class Matrix<MPInt>;
+    template class Matrix<Fraction>;
+  }
+}
\ No newline at end of file
diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp
index eff312b69e1de12..c90b7d54b0b7ad8 100644
--- a/mlir/lib/Analysis/Presburger/Simplex.cpp
+++ b/mlir/lib/Analysis/Presburger/Simplex.cpp
@@ -1801,7 +1801,7 @@ class presburger::GBRSimplex {
 ///
 /// When incrementing i, no cached f values get invalidated. However, the cached
 /// duals do get invalidated as the duals for the higher levels are different.
-void Simplex::reduceBasis(Matrix &basis, unsigned level) {
+void Simplex::reduceBasis(Matrix<MPInt> &basis, unsigned level) {
   const Fraction epsilon(3, 4);
 
   if (level == basis.getNumRows() - 1)
@@ -1975,7 +1975,7 @@ std::optional<SmallVector<MPInt, 8>> Simplex::findIntegerSample() {
     return {};
 
   unsigned nDims = var.size();
-  Matrix basis = Matrix::identity(nDims);
+  Matrix<MPInt> basis = Matrix<MPInt>::identity(nDims);
 
   unsigned level = 0;
   // The snapshot just before constraining a direction to a value at each level.
diff --git a/mlir/unittests/Analysis/Presburger/LinearTransformTest.cpp b/mlir/unittests/Analysis/Presburger/LinearTransformTest.cpp
index 32d9e532e1f67dc..07c1f9069bca21c 100644
--- a/mlir/unittests/Analysis/Presburger/LinearTransformTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/LinearTransformTest.cpp
@@ -13,7 +13,7 @@
 using namespace mlir;
 using namespace presburger;
 
-void testColumnEchelonForm(const Matrix &m, unsigned expectedRank) {
+void testColumnEchelonForm(const Matrix<MPInt> &m, unsigned expectedRank) {
   unsigned lastAllowedNonZeroCol = 0;
   std::pair<unsigned, LinearTransform> result =
       LinearTransform::makeTransformToColumnEchelon(m);
@@ -42,21 +42,21 @@ void testColumnEchelonForm(const Matrix &m, unsigned expectedRank) {
 
 TEST(LinearTransformTest, transformToColumnEchelonTest) {
   // m1, m2, m3 are rank 1 matrices -- the first and second rows are identical.
-  Matrix m1(2, 2);
+  Matrix<MPInt> m1(2, 2);
   m1(0, 0) = 4;
   m1(0, 1) = -7;
   m1(1, 0) = 4;
   m1(1, 1) = -7;
   testColumnEchelonForm(m1, 1u);
 
-  Matrix m2(2, 2);
+  Matrix<MPInt> m2(2, 2);
   m2(0, 0) = -4;
   m2(0, 1) = 7;
   m2(1, 0) = 4;
   m2(1, 1) = -7;
   testColumnEchelonForm(m2, 1u);
 
-  Matrix m3(2, 2);
+  Matrix<MPInt> m3(2, 2);
   m3(0, 0) = -4;
   m3(0, 1) = -7;
   m3(1, 0) = -4;
@@ -64,21 +64,21 @@ TEST(LinearTransformTest, transformToColumnEchelonTest) {
   testColumnEchelonForm(m3, 1u);
 
   // m4, m5, m6 are rank 2 matrices -- the first and second rows are different.
-  Matrix m4(2, 2);
+  Matrix<MPInt> m4(2, 2);
   m4(0, 0) = 4;
   m4(0, 1) = -7;
   m4(1, 0) = -4;
   m4(1, 1) = -7;
   testColumnEchelonForm(m4, 2u);
 
-  Matrix m5(2, 2);
+  Matrix<MPInt> m5(2, 2);
   m5(0, 0) = -4;
   m5(0, 1) = 7;
   m5(1, 0) = 4;
   m5(1, 1) = 7;
   testColumnEchelonForm(m5, 2u);
 
-  Matrix m6(2, 2);
+  Matrix<MPInt> m6(2, 2);
   m6(0, 0) = -4;
   m6(0, 1) = -7;
   m6(1, 0) = 4;
diff --git a/mlir/unittests/Analysis/Presburger/MatrixTest.cpp b/mlir/unittests/Analysis/Presburger/MatrixTest.cpp
index 5a1a827e6bb9a88..7a226936c5751eb 100644
--- a/mlir/unittests/Analysis/Presburger/MatrixTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/MatrixTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Analysis/Presburger/Matrix.h"
+#include "mlir/Analysis/Presburger/Fraction.h"
 #include "./Utils.h"
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
@@ -15,7 +16,7 @@ using namespace mlir;
 using namespace presburger;
 
 TEST(MatrixTest, ReadWrite) {
-  Matrix mat(5, 5);
+  Matrix<MPInt> mat(5, 5);
   for (unsigned row = 0; row < 5; ++row)
     for (unsigned col = 0; col < 5; ++col)
       mat(row, col) = 10 * row + col;
@@ -25,7 +26,7 @@ TEST(MatrixTest, ReadWrite) {
 }
 
 TEST(MatrixTest, SwapColumns) {
-  Matrix mat(5, 5);
+  Matrix<MPInt> mat(5, 5);
   for (unsigned row = 0; row < 5; ++row)
     for (unsigned col = 0; col < 5; ++col)
       mat(row, col) = col == 3 ? 1 : 0;
@@ -47,7 +48,7 @@ TEST(MatrixTest, SwapColumns) {
 }
 
 TEST(MatrixTest, SwapRows) {
-  Matrix mat(5, 5);
+  Matrix<MPInt> mat(5, 5);
   for (unsigned row = 0; row < 5; ++row)
     for (unsigned col = 0; col < 5; ++col)
       mat(row, col) = row == 2 ? 1 : 0;
@@ -69,7 +70,7 @@ TEST(MatrixTest, SwapRows) {
 }
 
 TEST(MatrixTest, resizeVertically) {
-  Matrix mat(5, 5);
+  Matrix<MPInt> mat(5, 5);
   EXPECT_EQ(mat.getNumRows(), 5u);
   EXPECT_EQ(mat.getNumColumns(), 5u);
   for (unsigned row = 0; row < 5; ++row)
@@ -94,7 +95,7 @@ TEST(MatrixTest, resizeVertically) {
 }
 
 TEST(MatrixTest, insertColumns) {
-  Matrix mat(5, 5, 5, 10);
+  Matrix<MPInt> mat(5, 5, 5, 10);
   EXPECT_EQ(mat.getNumRows(), 5u);
   EXPECT_EQ(mat.getNumColumns(), 5u);
   for (unsigned row = 0; row < 5; ++row)
@@ -131,7 +132,7 @@ TEST(MatrixTest, insertColumns) {
 }
 
 TEST(MatrixTest, insertRows) {
-  Matrix mat(5, 5, 5, 10);
+  Matrix<MPInt> mat(5, 5, 5, 10);
   ASSERT_TRUE(mat.hasConsistentState());
   EXPECT_EQ(mat.getNumRows(), 5u);
   EXPECT_EQ(mat.getNumColumns(), 5u);
@@ -169,7 +170,7 @@ TEST(MatrixTest, insertRows) {
 }
 
 TEST(MatrixTest, resize) {
-  Matrix mat(5, 5);
+  Matrix<MPInt> mat(5, 5);
   EXPECT_EQ(mat.getNumRows(), 5u);
   EXPECT_EQ(mat.getNumColumns(), 5u);
   for (unsigned row = 0; row < 5; ++row)
@@ -193,8 +194,8 @@ TEST(MatrixTest, resize) {
       EXPECT_EQ(mat(row, col), row >= 3 || col >= 3 ? 0 : int(10 * row + col));
 }
 
-static void checkHermiteNormalForm(const Matrix &mat,
-                                   const Matrix &hermiteForm) {
+static void checkHermiteNormalForm(const Matrix<MPInt> &mat,
+                                   const Matrix<MPInt> &hermiteForm) {
   auto [h, u] = mat.computeHermiteNormalForm();
 
   for (unsigned row = 0; row < mat.getNumRows(); row++)
@@ -208,42 +209,42 @@ TEST(MatrixTest, computeHermiteNormalForm) {
 
   {
     // Hermite form of a unimodular matrix is the identity matrix.
-    Matrix mat = makeMatrix(3, 3, {{2, 3, 6}, {3, 2, 3}, {17, 11, 16}});
-    Matrix hermiteForm = makeMatrix(3, 3, {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}});
+    Matrix<MPInt> mat = makeIntMatrix(3, 3, {{2, 3, 6}, {3, 2, 3}, {17, 11, 16}});
+    Matrix<MPInt> hermiteForm = makeIntMatrix(3, 3, {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}});
     checkHermiteNormalForm(mat, hermiteForm);
   }
 
   {
     // Hermite form of a unimodular is the identity matrix.
-    Matrix mat = makeMatrix(
+    Matrix<MPInt> mat = makeIntMatrix(
         4, 4,
         {{-6, -1, -19, -20}, {0, 1, 0, 0}, {-5, 0, -15, -16}, {6, 0, 18, 19}});
-    Matrix hermiteForm = makeMatrix(
+    Matrix<MPInt> hermiteForm = makeIntMatrix(
         4, 4, {{1, 0, 0, 0}, {0, 1, 0, 0}, {0, 0, 1, 0}, {0, 0, 0, 1}});
     checkHermiteNormalForm(mat, hermiteForm);
   }
 
   {
-    Matrix mat = makeMatrix(
+    Matrix<MPInt> mat = makeIntMatrix(
         4, 4, {{3, 3, 1, 4}, {0, 1, 0, 0}, {0, 0, 19, 16}, {0, 0, 0, 3}});
-    Matrix hermiteForm = makeMatrix(
+    Matrix<MPInt> hermiteForm = makeIntMatrix(
         4, 4, {{1, 0, 0, 0}, {0, 1, 0, 0}, {1, 0, 3, 0}, {18, 0, 54, 57}});
     checkHermiteNormalForm(mat, hermiteForm);
   }
 
   {
-    Matrix mat = makeMatrix(
+    Matrix<MPInt> mat = makeIntMatrix(
         4, 4, {{3, 3, 1, 4}, {0, 1, 0, 0}, {0, 0, 19, 16}, {0, 0, 0, 3}});
-    Matrix hermiteForm = makeMatrix(
+    Matrix<MPInt> hermiteForm = makeIntMatrix(
         4, 4, {{1, 0, 0, 0}, {0, 1, 0, 0}, {1, 0, 3, 0}, {18, 0, 54, 57}});
     checkHermiteNormalForm(mat, hermiteForm);
   }
 
   {
-    Matrix mat =
-        makeMatrix(3, 5, {{0, 2, 0, 7, 1}, {-1, 0, 0, -3, 0}, {0, 4, 1, 0, 8}});
-    Matrix hermiteForm =
-        makeMatrix(3, 5, {{1, 0, 0, 0, 0}, {0, 1, 0, 0, 0}, {0, 0, 1, 0, 0}});
+    Matrix<MPInt> mat =
+        makeIntMatrix(3, 5, {{0, 2, 0, 7, 1}, {-1, 0, 0, -3, 0}, {0, 4, 1, 0, 8}});
+    Matrix<MPInt> hermiteForm =
+        makeIntMatrix(3, 5, {{1, 0, 0, 0, 0}, {0, 1, 0, 0, 0}, {0, 0, 1, 0, 0}});
     checkHermiteNormalForm(mat, hermiteForm);
   }
 }
diff --git a/mlir/unittests/Analysis/Presburger/Parser.h b/mlir/unittests/Analysis/Presburger/Parser.h
index c2c63730056e7fe..bd9b6f07664c7e7 100644
--- a/mlir/unittests/Analysis/Presburger/Parser.h
+++ b/mlir/unittests/Analysis/Presburger/Parser.h
@@ -52,7 +52,7 @@ inline MultiAffineFunction parseMultiAffineFunction(StringRef str) {
 
   // TODO: Add default constructor for MultiAffineFunction.
   MultiAffineFunction multiAff(PresburgerSpace::getRelationSpace(),
-                               Matrix(0, 1));
+                               Matrix<MPInt>(0, 1));
   if (getMultiAffineFunctionFromMap(parseAffineMap(str, &context), multiAff)
           .failed())
     llvm_unreachable(
diff --git a/mlir/unittests/Analysis/Presburger/Utils.h b/mlir/unittests/Analysis/Presburger/Utils.h
index c3246a09d5ae9be..8a7f86c866b7056 100644
--- a/mlir/unittests/Analysis/Presburger/Utils.h
+++ b/mlir/unittests/Analysis/Presburger/Utils.h
@@ -17,6 +17,7 @@
 #include "mlir/Analysis/Presburger/PWMAFunction.h"
 #include "mlir/Analysis/Presburger/PresburgerRelation.h"
 #include "mlir/Analysis/Presburger/Simplex.h"
+#include "mlir/Analysis/Presburger/Matrix.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/Support/LLVM.h"
 
@@ -26,9 +27,22 @@
 namespace mlir {
 namespace presburger {
 
-inline Matrix makeMatrix(unsigned numRow, unsigned numColumns,
-                         ArrayRef<SmallVector<int64_t, 8>> matrix) {
-  Matrix results(numRow, numColumns);
+inline Matrix<MPInt> makeIntMatrix(unsigned numRow, unsigned numColumns,
+                         ArrayRef<SmallVector<int, 8>> matrix) {
+  Matrix<MPInt> results(numRow, numColumns);
+  assert(matrix.size() == numRow);
+  for (unsigned i = 0; i < numRow; ++i) {
+    assert(matrix[i].size() == numColumns &&
+           "Output expression has incorrect dimensionality!");
+    for (unsigned j = 0; j < numColumns; ++j)
+      results(i, j) = MPInt(matrix[i][j]);
+  }
+  return results;
+}
+
+inline Matrix<Fraction> makeFracMatrix(unsigned numRow, unsigned numColumns,
+                         ArrayRef<SmallVector<Fraction, 8>> matrix) {
+  Matrix<Fraction> results(numRow, numColumns);
   assert(matrix.size() == numRow);
   for (unsigned i = 0; i < numRow; ++i) {
     assert(matrix[i].size() == numColumns &&

>From 3520fd022b0cd9335efd6d4e4eca696387b23e57 Mon Sep 17 00:00:00 2001
From: Abhinav271828 <abhinav.m at research.iiit.ac.in>
Date: Fri, 1 Sep 2023 15:54:28 +0100
Subject: [PATCH 3/4] Fix rebase conflict

---
 mlir/lib/Analysis/Presburger/Simplex.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp
index c90b7d54b0b7ad8..01d350d1e96ece8 100644
--- a/mlir/lib/Analysis/Presburger/Simplex.cpp
+++ b/mlir/lib/Analysis/Presburger/Simplex.cpp
@@ -436,7 +436,7 @@ LogicalResult SymbolicLexSimplex::addSymbolicCut(unsigned row) {
 }
 
 void SymbolicLexSimplex::recordOutput(SymbolicLexOpt &result) const {
-  Matrix output(0, domainPoly.getNumVars() + 1);
+  Matrix<MPInt> output(0, domainPoly.getNumVars() + 1);
   output.reserveRows(result.lexopt.getNumOutputs());
   for (const Unknown &u : var) {
     if (u.isSymbol)

>From 6f8e2c3be657c321345f70e57ef1ec176f56ad1b Mon Sep 17 00:00:00 2001
From: Abhinav271828 <abhinav.m at research.iiit.ac.in>
Date: Mon, 4 Sep 2023 13:36:14 +0100
Subject: [PATCH 4/4] Add static assert

---
 mlir/include/mlir/Analysis/Presburger/Fraction.h |  3 +--
 mlir/include/mlir/Analysis/Presburger/Matrix.h   |  7 ++++++-
 mlir/lib/Analysis/Presburger/IntegerRelation.cpp |  2 +-
 mlir/lib/Analysis/Presburger/Matrix.cpp          | 10 +++++-----
 4 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Analysis/Presburger/Fraction.h b/mlir/include/mlir/Analysis/Presburger/Fraction.h
index 2cb90b708435353..e88f5d768f0c40f 100644
--- a/mlir/include/mlir/Analysis/Presburger/Fraction.h
+++ b/mlir/include/mlir/Analysis/Presburger/Fraction.h
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This is a simple class to represent fractions. It supports multiplication,
+// This is a simple class to represent fractions. It supports arithmetic,
 // comparison, floor, and ceiling operations.
 //
 //===----------------------------------------------------------------------===//
@@ -15,7 +15,6 @@
 #define MLIR_ANALYSIS_PRESBURGER_FRACTION_H
 
 #include "mlir/Analysis/Presburger/MPInt.h"
-#include "mlir/Analysis/Presburger/Utils.h"
 #include "mlir/Support/MathExtras.h"
 
 namespace mlir {
diff --git a/mlir/include/mlir/Analysis/Presburger/Matrix.h b/mlir/include/mlir/Analysis/Presburger/Matrix.h
index b03737ab2f70a4a..d09981d79cac819 100644
--- a/mlir/include/mlir/Analysis/Presburger/Matrix.h
+++ b/mlir/include/mlir/Analysis/Presburger/Matrix.h
@@ -7,7 +7,8 @@
 //===----------------------------------------------------------------------===//
 //
 // This is a simple 2D matrix class that supports reading, writing, resizing,
-// swapping rows, and swapping columns.
+// swapping rows, and swapping columns. It can hold integers (MPInt) or rational
+// numbers (Fraction).
 //
 //===----------------------------------------------------------------------===//
 
@@ -15,6 +16,8 @@
 #define MLIR_ANALYSIS_PRESBURGER_MATRIX_H
 
 #include "mlir/Support/LLVM.h"
+#include "mlir/Analysis/Presburger/Fraction.h"
+#include "mlir/Analysis/Presburger/Matrix.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/raw_ostream.h"
 
@@ -33,6 +36,8 @@ namespace presburger {
 /// space in the underlying SmallVector's capacity.
 template<typename T>
 class Matrix {
+  // This class is not intended for general use: it supports only integers (MPInt) and rational numbers (Fraction).
+static_assert(std::is_same_v<T,MPInt> || std::is_same_v<T,Fraction>, "T must be MPInt or Fraction.");
 public:
   Matrix() = delete;
 
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index 4672de03b40693d..118ed1f19ce62a0 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -304,7 +304,7 @@ SymbolicLexOpt IntegerRelation::findSymbolicIntegerLexMax() const {
   // Get lexmax by flipping range sign in the PWMA constraints.
   for (auto &flippedPiece :
        flippedSymbolicIntegerLexMax.lexopt.getAllPieces()) {
-    Matrix mat = flippedPiece.output.getOutputMatrix();
+    Matrix<MPInt> mat = flippedPiece.output.getOutputMatrix();
     for (unsigned i = 0, e = mat.getNumRows(); i < e; i++)
       mat.negateRow(i);
     MultiAffineFunction maf(flippedPiece.output.getSpace(), mat);
diff --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp
index c19e5d8d49fec37..e526af9809ee0e2 100644
--- a/mlir/lib/Analysis/Presburger/Matrix.cpp
+++ b/mlir/lib/Analysis/Presburger/Matrix.cpp
@@ -379,9 +379,9 @@ template <typename T> bool Matrix<T>::hasConsistentState() const {
 
 namespace mlir
 {
-  namespace presburger
-  {
-    template class Matrix<MPInt>;
-    template class Matrix<Fraction>;
-  }
+namespace presburger
+{
+template class Matrix<MPInt>;
+template class Matrix<Fraction>;
+}
 }
\ No newline at end of file



More information about the Mlir-commits mailing list