[Mlir-commits] [mlir] 2dde029 - [MLIR][Presburger] Implement computation of generating function for unimodular cones (#77235)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 10 11:58:41 PST 2024


Author: Abhinav271828
Date: 2024-01-11T01:28:36+05:30
New Revision: 2dde029df8f9e3b2ece6899dc73bea226f227d11

URL: https://github.com/llvm/llvm-project/commit/2dde029df8f9e3b2ece6899dc73bea226f227d11
DIFF: https://github.com/llvm/llvm-project/commit/2dde029df8f9e3b2ece6899dc73bea226f227d11.diff

LOG: [MLIR][Presburger] Implement computation of generating function for unimodular cones (#77235)

We implement a function that computes the generating function
corresponding to a unimodular cone.
The generating function for a polytope is obtained by summing these
generating functions over all tangent cones.

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/Presburger/Barvinok.h
    mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
    mlir/include/mlir/Analysis/Presburger/Matrix.h
    mlir/lib/Analysis/Presburger/Barvinok.cpp
    mlir/lib/Analysis/Presburger/Matrix.cpp
    mlir/unittests/Analysis/Presburger/BarvinokTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/Presburger/Barvinok.h b/mlir/include/mlir/Analysis/Presburger/Barvinok.h
index 15e805860db237..213af636e5964d 100644
--- a/mlir/include/mlir/Analysis/Presburger/Barvinok.h
+++ b/mlir/include/mlir/Analysis/Presburger/Barvinok.h
@@ -24,6 +24,7 @@
 #ifndef MLIR_ANALYSIS_PRESBURGER_BARVINOK_H
 #define MLIR_ANALYSIS_PRESBURGER_BARVINOK_H
 
+#include "mlir/Analysis/Presburger/GeneratingFunction.h"
 #include "mlir/Analysis/Presburger/IntegerRelation.h"
 #include "mlir/Analysis/Presburger/Matrix.h"
 #include <optional>
@@ -77,6 +78,11 @@ ConeV getDual(ConeH cone);
 /// The returned cone is pointed at the origin.
 ConeH getDual(ConeV cone);
 
+/// Compute the generating function for a unimodular cone, given its vertex
+/// and sign. The input cone must be unimodular; this is checked by an assert.
+GeneratingFunction unimodularConeGeneratingFunction(ParamPoint vertex, int sign,
+                                                    ConeH cone);
+
 } // namespace detail
 } // namespace presburger
 } // namespace mlir

diff  --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index cd957280eb740d..8e2c9fca0a17cb 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -221,6 +221,8 @@ class IntegerRelation {
     return getInt64Vec(inequalities.getRow(idx));
   }
 
+  inline IntMatrix getInequalities() const { return inequalities; } // A copy.
+
   /// Get the number of vars of the specified kind.
   unsigned getNumVarKind(VarKind kind) const {
     return space.getNumVarKind(kind);

diff  --git a/mlir/include/mlir/Analysis/Presburger/Matrix.h b/mlir/include/mlir/Analysis/Presburger/Matrix.h
index 347e2e0489786f..38fac50c13536e 100644
--- a/mlir/include/mlir/Analysis/Presburger/Matrix.h
+++ b/mlir/include/mlir/Analysis/Presburger/Matrix.h
@@ -181,6 +181,9 @@ class Matrix {
   /// `elems` must be equal to the number of columns.
   unsigned appendExtraRow(ArrayRef<T> elems);
 
+  /// Return the transpose of this matrix; the matrix itself is not modified.
+  Matrix<T> transpose() const;
+
   /// Print the matrix.
   void print(raw_ostream &os) const;
   void dump() const;

diff  --git a/mlir/lib/Analysis/Presburger/Barvinok.cpp b/mlir/lib/Analysis/Presburger/Barvinok.cpp
index 9152b66968a1f5..0bdc9015c3d647 100644
--- a/mlir/lib/Analysis/Presburger/Barvinok.cpp
+++ b/mlir/lib/Analysis/Presburger/Barvinok.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Analysis/Presburger/Barvinok.h"
+#include "llvm/ADT/Sequence.h"
 
 using namespace mlir;
 using namespace presburger;
@@ -24,7 +25,7 @@ ConeV mlir::presburger::detail::getDual(ConeH cone) {
   // is represented as a row [a1, ..., an, b]
   // and that b = 0.
 
-  for (unsigned i = 0; i < numIneq; ++i) {
+  for (auto i : llvm::seq<int>(0, numIneq)) {
     assert(cone.atIneq(i, numVar) == 0 &&
            "H-representation of cone is not centred at the origin!");
     for (unsigned j = 0; j < numVar; ++j) {
@@ -63,3 +64,83 @@ MPInt mlir::presburger::detail::getIndex(ConeV cone) {
 
   return cone.determinant();
 }
+
+/// Compute the generating function for a unimodular cone.
+/// This consists of a single term of the form
+/// sign * x^num / prod_j (1 - x^den_j)
+///
+/// sign is either +1 or -1.
+/// den_j is the j'th generator of the cone.
+/// num is computed by expressing the vertex as a weighted
+/// sum of the generators, and then rounding each coefficient c
+/// to -floor(-c).
+GeneratingFunction mlir::presburger::detail::unimodularConeGeneratingFunction(
+    ParamPoint vertex, int sign, ConeH cone) {
+  // Consider a cone with H-representation [0  -1].
+  //                                       [-1 -2]
+  // Let the vertex be given by the matrix [ 2  2   0], with 2 params.
+  //                                       [-1 -1/2 1]
+
+  // `cone` must be unimodular, i.e., the index of its dual must be 1.
+  assert(getIndex(getDual(cone)) == 1 && "input cone is not unimodular!");
+
+  unsigned numVar = cone.getNumVars();
+  unsigned numIneq = cone.getNumInequalities();
+
+  // Thus its ray matrix, U, is the inverse of the
+  // transpose of its inequality matrix, `cone`.
+  // The last column of the inequality matrix is zero, so after
+  // transposing we remove the last row to obtain a square matrix.
+  FracMatrix transp = FracMatrix(cone.getInequalities()).transpose();
+  transp.removeRow(numVar);
+
+  FracMatrix generators(numVar, numIneq);
+  transp.determinant(/*inverse=*/&generators); // This is the U-matrix.
+  // Thus the generators are given by U = [2  -1].
+  //                                      [-1  0]
+
+  // The powers in the denominator of the generating
+  // function are given by the generators of the cone,
+  // i.e., the rows of the matrix U.
+  std::vector<Point> denominator(numIneq);
+  ArrayRef<Fraction> row;
+  for (auto i : llvm::seq<int>(0, numVar)) {
+    row = generators.getRow(i);
+    denominator[i] = Point(row);
+  }
+
+  // The vertex is v \in Q^{d x (n+1)} (n parameters plus a constant term).
+  // We need to find affine functions of parameters λ_i(p)
+  // such that v = Σ λ_i(p)*u_i,
+  // where u_i are the rows of U (generators).
+  // The λ_i are given by the columns of Λ = v^T U^{-1}, and
+  // we have transp = U^{-1}.
+  // Then the exponent in the numerator will be
+  // Σ -floor(-λ_i(p))*u_i.
+  // Thus we store the (exponent of the) numerator as the affine function -Λ,
+  // since the generators u_i are already stored as the exponent of the
+  // denominator. Note that the outer -1 will have to be accounted for, as it is
+  // not stored. See the end of this function for an example.
+
+  unsigned numColumns = vertex.getNumColumns();
+  unsigned numRows = vertex.getNumRows();
+  ParamPoint numerator(numColumns, numRows);
+  SmallVector<Fraction> ithCol(numRows);
+  for (auto i : llvm::seq<int>(0, numColumns)) {
+    for (auto j : llvm::seq<int>(0, numRows))
+      ithCol[j] = vertex(j, i);
+    numerator.setRow(i, transp.preMultiplyWithRow(ithCol));
+    numerator.negateRow(i);
+  }
+  // In the example, Λ is therefore [ 1    0 ] and the negation of this is
+  //                                [ 1/2 -1 ]
+  //                                [ -1  -2 ]
+  // stored as the numerator.
+  // Algebraically, the numerator exponent is
+  // [ -2 ⌊ - N - M/2 + 1 ⌋ + 1 ⌊ 0 + M + 2 ⌋ ] -> first  COLUMN of U is [2, -1]
+  // [  1 ⌊ - N - M/2 + 1 ⌋ + 0 ⌊ 0 + M + 2 ⌋ ] -> second COLUMN of U is [-1, 0]
+
+  return GeneratingFunction(numColumns - 1, SmallVector<int>(1, sign),
+                            std::vector({numerator}),
+                            std::vector({denominator}));
+}

diff  --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp
index b68a7b7004bba9..349520747c5d6b 100644
--- a/mlir/lib/Analysis/Presburger/Matrix.cpp
+++ b/mlir/lib/Analysis/Presburger/Matrix.cpp
@@ -62,6 +62,16 @@ unsigned Matrix<T>::appendExtraRow(ArrayRef<T> elems) {
   return row;
 }
 
+template <typename T>
+Matrix<T> Matrix<T>::transpose() const { // Builds and returns a new matrix.
+  Matrix<T> transp(nColumns, nRows); // Row and column counts are swapped.
+  for (unsigned row = 0; row < nRows; ++row)
+    for (unsigned col = 0; col < nColumns; ++col)
+      transp(col, row) = at(row, col); // Element (row, col) -> (col, row).
+
+  return transp;
+}
+
 template <typename T>
 void Matrix<T>::resizeHorizontally(unsigned newNColumns) {
   if (newNColumns < nColumns)

diff  --git a/mlir/unittests/Analysis/Presburger/BarvinokTest.cpp b/mlir/unittests/Analysis/Presburger/BarvinokTest.cpp
index b88baa6c6b48a4..2936d95c802e9c 100644
--- a/mlir/unittests/Analysis/Presburger/BarvinokTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/BarvinokTest.cpp
@@ -46,3 +46,39 @@ TEST(BarvinokTest, getIndex) {
       4, 4, {{4, 2, 5, 1}, {4, 1, 3, 6}, {8, 2, 5, 6}, {5, 2, 5, 7}});
   EXPECT_EQ(getIndex(cone), cone.determinant());
 }
+
+// The following cones and vertices are randomly generated
+// (such that the cones are unimodular), and their generating functions
+// are computed. We check that the resulting numerator and denominator
+// exponents match the expected matrices.
+TEST(BarvinokTest, unimodularConeGeneratingFunction) {
+  ConeH cone = defineHRep(2);
+  cone.addInequality({0, -1, 0});
+  cone.addInequality({-1, -2, 0});
+
+  ParamPoint vertex =
+      makeFracMatrix(2, 3, {{2, 2, 0}, {-1, -Fraction(1, 2), 1}});
+
+  GeneratingFunction gf = unimodularConeGeneratingFunction(vertex, 1, cone);
+
+  EXPECT_EQ_REPR_GENERATINGFUNCTION(
+      gf, GeneratingFunction(
+              2, {1},
+              {makeFracMatrix(3, 2, {{-1, 0}, {-Fraction(1, 2), 1}, {1, 2}})},
+              {{{2, -1}, {-1, 0}}}));
+
+  cone = defineHRep(3);
+  cone.addInequality({7, 1, 6, 0});
+  cone.addInequality({9, 1, 7, 0});
+  cone.addInequality({8, -1, 1, 0});
+
+  vertex = makeFracMatrix(3, 2, {{5, 2}, {6, 2}, {7, 1}});
+
+  gf = unimodularConeGeneratingFunction(vertex, 1, cone);
+
+  EXPECT_EQ_REPR_GENERATINGFUNCTION(
+      gf,
+      GeneratingFunction(
+          1, {1}, {makeFracMatrix(2, 3, {{-83, -100, -41}, {-22, -27, -15}})},
+          {{{8, 47, -17}, {-7, -41, 15}, {1, 5, -2}}}));
+}


        


More information about the Mlir-commits mailing list