[Mlir-commits] [mlir] 9f8cb68 - [MLIR][Presburger] Support finding integer lexmin in IntegerPolyhedron

Arjun P llvmlistbot at llvm.org
Mon Feb 21 13:02:26 PST 2022


Author: Arjun P
Date: 2022-02-21T21:02:21Z
New Revision: 9f8cb68570d886025df36445ae04d4e16e32a128

URL: https://github.com/llvm/llvm-project/commit/9f8cb68570d886025df36445ae04d4e16e32a128
DIFF: https://github.com/llvm/llvm-project/commit/9f8cb68570d886025df36445ae04d4e16e32a128.diff

LOG: [MLIR][Presburger] Support finding integer lexmin in IntegerPolyhedron

Note: this does not yet support PresburgerSets.

Reviewed By: Groverkss

Differential Revision: https://reviews.llvm.org/D120239

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/Presburger/Fraction.h
    mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h
    mlir/include/mlir/Analysis/Presburger/Simplex.h
    mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp
    mlir/lib/Analysis/Presburger/Simplex.cpp
    mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/Presburger/Fraction.h b/mlir/include/mlir/Analysis/Presburger/Fraction.h
index 0f8ee6e01636b..c1ff333f6e441 100644
--- a/mlir/include/mlir/Analysis/Presburger/Fraction.h
+++ b/mlir/include/mlir/Analysis/Presburger/Fraction.h
@@ -25,7 +25,7 @@ namespace mlir {
 /// representable by 64-bit integers.
 struct Fraction {
   /// Default constructor initializes the represented rational number to zero.
-  Fraction() {}
+  Fraction() = default;
 
   /// Construct a Fraction from a numerator and denominator.
   Fraction(int64_t oNum, int64_t oDen) : num(oNum), den(oDen) {
@@ -35,6 +35,13 @@ struct Fraction {
     }
   }
 
+  /// Return the value of the fraction as an integer. Must only be called when
+  /// the fraction's value is an integer (den divides num); this is asserted.
+  int64_t getAsInteger() const {
+    assert(num % den == 0 && "Get as integer called on non-integral fraction!");
+    return num / den;
+  }
+
   /// The numerator and denominator, respectively. The denominator is always
   /// positive.
   int64_t num{0}, den{1};

diff  --git a/mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h b/mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h
index 1a786d89f27b8..711a34950e753 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h
@@ -212,6 +212,13 @@ class IntegerPolyhedron : public PresburgerLocalSpace {
   presburger_utils::MaybeOptimum<SmallVector<Fraction, 8>>
   getRationalLexMin() const;
 
+  /// Same as getRationalLexMin, but returns the lexicographically minimal
+  /// integer point. Symbols are not supported. Use this only when the lexmin
+  /// is really required; for generic integer sampling, findIntegerSample is
+  /// more robust and should be preferred.
+  presburger_utils::MaybeOptimum<SmallVector<int64_t, 8>>
+  getIntegerLexMin() const;
+
   /// Swap the posA^th identifier with the posB^th identifier.
   virtual void swapId(unsigned posA, unsigned posB);
 

diff  --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h
index 10600064710dc..d5e14f717e925 100644
--- a/mlir/include/mlir/Analysis/Presburger/Simplex.h
+++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h
@@ -265,6 +265,10 @@ class SimplexBase {
   /// Returns the unknown associated with row.
   Unknown &unknownFromRow(unsigned row);
 
+  /// Add a new row to the tableau and the associated data structures. The row
+  /// is initialized to zero.
+  unsigned addZeroRow(bool makeRestricted = false);
+
   /// Add a new row to the tableau and the associated data structures.
   /// The new row is considered to be a constraint; the new Unknown lives in
   /// con.
@@ -436,6 +440,12 @@ class LexSimplex : public SimplexBase {
   /// Return the lexicographically minimum rational solution to the constraints.
   presburger_utils::MaybeOptimum<SmallVector<Fraction, 8>> getRationalLexMin();
 
+  /// Return the lexicographically minimum integer solution to the constraints.
+  ///
+  /// Note: this should be used only when the lexmin is really needed. To obtain
+  /// any integer sample, use Simplex::findIntegerSample as that is more robust.
+  presburger_utils::MaybeOptimum<SmallVector<int64_t, 8>> getIntegerLexMin();
+
 protected:
   /// Returns the current sample point, which may contain non-integer (rational)
   /// coordinates. Returns an empty optimum when the tableau is empty.
@@ -446,6 +456,15 @@ class LexSimplex : public SimplexBase {
   presburger_utils::MaybeOptimum<SmallVector<Fraction, 8>>
   getRationalSample() const;
 
+  /// Given a row that has a non-integer sample value, add an inequality such
+  /// that this fractional sample value is cut away from the polytope. The added
+  /// inequality will be such that no integer points are removed.
+  ///
+  /// Returns whether the cut constraint could be enforced, i.e. failure if the
+  /// cut made the polytope empty, and success if it didn't. Failure status
+  /// indicates that the polytope didn't have any integer points.
+  LogicalResult addCut(unsigned row);
+
   /// Undo the addition of the last constraint. This is only called while
   /// rolling back.
   void undoLastConstraint() final;
@@ -460,6 +479,10 @@ class LexSimplex : public SimplexBase {
   /// Otherwise, return an empty optional.
   Optional<unsigned> maybeGetViolatedRow() const;
 
+  /// Get a row corresponding to a var that has a non-integral sample value, if
+  /// one exists. Otherwise, return an empty optional.
+  Optional<unsigned> maybeGetNonIntegralVarRow() const;
+
   /// Given two potential pivot columns for a row, return the one that results
   /// in the lexicographically smallest sample vector.
   unsigned getLexMinPivotColumn(unsigned row, unsigned colA,

diff  --git a/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp b/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp
index ce0f339967a52..5e26149303e6e 100644
--- a/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp
@@ -92,6 +92,26 @@ IntegerPolyhedron::getRationalLexMin() const {
   return maybeLexMin;
 }
 
+MaybeOptimum<SmallVector<int64_t, 8>>
+IntegerPolyhedron::getIntegerLexMin() const {
+  assert(getNumSymbolIds() == 0 && "Symbols are not supported!");
+  MaybeOptimum<SmallVector<int64_t, 8>> maybeLexMin =
+      LexSimplex(*this).getIntegerLexMin();
+
+  if (!maybeLexMin.isBounded())
+    return maybeLexMin.getKind();
+
+  // The Simplex returns the lexmin over all the variables including locals. But
+  // locals are not actually part of the space and should not be returned in the
+  // result. Since the locals are placed last in the list of identifiers, they
+  // will be minimized last in the lexmin. So simply truncating out the locals
+  // from the end of the answer gives the desired lexmin over the dimensions.
+  assert(maybeLexMin->size() == getNumIds() &&
+         "Incorrect number of vars in lexMin!");
+  maybeLexMin->resize(getNumDimAndSymbolIds());
+  return maybeLexMin;
+}
+
 unsigned IntegerPolyhedron::insertDimId(unsigned pos, unsigned num) {
   return insertId(IdKind::SetDim, pos, num);
 }

diff  --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp
index 285fa91f34a07..79ccae57573e5 100644
--- a/mlir/lib/Analysis/Presburger/Simplex.cpp
+++ b/mlir/lib/Analysis/Presburger/Simplex.cpp
@@ -59,13 +59,7 @@ Simplex::Unknown &SimplexBase::unknownFromRow(unsigned row) {
   return unknownFromIndex(rowUnknown[row]);
 }
 
-/// Add a new row to the tableau corresponding to the given constant term and
-/// list of coefficients. The coefficients are specified as a vector of
-/// (variable index, coefficient) pairs.
-unsigned SimplexBase::addRow(ArrayRef<int64_t> coeffs, bool makeRestricted) {
-  assert(coeffs.size() == var.size() + 1 &&
-         "Incorrect number of coefficients!");
-
+unsigned SimplexBase::addZeroRow(bool makeRestricted) {
   ++nRow;
   // If the tableau is not big enough to accomodate the extra row, we extend it.
   if (nRow >= tableau.getNumRows())
@@ -77,6 +71,17 @@ unsigned SimplexBase::addRow(ArrayRef<int64_t> coeffs, bool makeRestricted) {
   tableau.fillRow(nRow - 1, 0);
 
   tableau(nRow - 1, 0) = 1;
+  return con.size() - 1;
+}
+
+/// Add a new row to the tableau corresponding to the given constant term and
+/// list of coefficients. The coefficients are specified as a vector of
+/// (variable index, coefficient) pairs.
+unsigned SimplexBase::addRow(ArrayRef<int64_t> coeffs, bool makeRestricted) {
+  assert(coeffs.size() == var.size() + 1 &&
+         "Incorrect number of coefficients!");
+
+  addZeroRow(makeRestricted);
   tableau(nRow - 1, 1) = coeffs.back();
   if (usingBigM) {
     // When the lexicographic pivot rule is used, instead of the variables
@@ -164,6 +169,56 @@ MaybeOptimum<SmallVector<Fraction, 8>> LexSimplex::getRationalLexMin() {
   return getRationalSample();
 }
 
+LogicalResult LexSimplex::addCut(unsigned row) {
+  int64_t denom = tableau(row, 0);
+  addZeroRow(/*makeRestricted=*/true);
+  tableau(nRow - 1, 0) = denom;
+  tableau(nRow - 1, 1) = -mod(-tableau(row, 1), denom);
+  tableau(nRow - 1, 2) = 0; // M has all factors in it.
+  for (unsigned col = 3; col < nCol; ++col)
+    tableau(nRow - 1, col) = mod(tableau(row, col), denom);
+  return moveRowUnknownToColumn(nRow - 1);
+}
+
+Optional<unsigned> LexSimplex::maybeGetNonIntegralVarRow() const {
+  for (const Unknown &u : var) {
+    if (u.orientation == Orientation::Column)
+      continue;
+    // If the sample value is of the form (a/d)M + b/d, we need b to be
+    // divisible by d. We assume M is very large and contains all possible
+    // factors and is divisible by everything.
+    unsigned row = u.pos;
+    if (tableau(row, 1) % tableau(row, 0) != 0)
+      return row;
+  }
+  return {};
+}
+
+MaybeOptimum<SmallVector<int64_t, 8>> LexSimplex::getIntegerLexMin() {
+  while (!empty) {
+    restoreRationalConsistency();
+    if (empty)
+      return OptimumKind::Empty;
+
+    if (Optional<unsigned> maybeRow = maybeGetNonIntegralVarRow()) {
+      // Failure occurs when the polytope is integer empty.
+      if (failed(addCut(*maybeRow)))
+        return OptimumKind::Empty;
+      continue;
+    }
+
+    MaybeOptimum<SmallVector<Fraction, 8>> sample = getRationalSample();
+    assert(!sample.isEmpty() && "If we reached here the sample should exist!");
+    if (sample.isUnbounded())
+      return OptimumKind::Unbounded;
+    return llvm::to_vector<8>(llvm::map_range(
+        *sample, [](const Fraction &f) { return f.getAsInteger(); }));
+  }
+
+  // Polytope is integer empty.
+  return OptimumKind::Empty;
+}
+
 bool LexSimplex::rowIsViolated(unsigned row) const {
   if (tableau(row, 2) < 0)
     return true;

diff  --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
index d7e9b967136b5..fffbf7527f994 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Analysis/Presburger/IntegerPolyhedron.h"
 #include "./Utils.h"
+#include "mlir/Analysis/Presburger/Simplex.h"
 #include "mlir/IR/MLIRContext.h"
 
 #include <gmock/gmock.h>
@@ -36,29 +37,53 @@ makeSetFromConstraints(unsigned ids, ArrayRef<SmallVector<int64_t, 4>> ineqs,
   return set;
 }
 
+static void dump(ArrayRef<int64_t> vec) {
+  for (int64_t x : vec)
+    llvm::errs() << x << ' ';
+  llvm::errs() << '\n';
+}
+
 /// If fn is TestFunction::Sample (default):
-/// If hasSample is true, check that findIntegerSample returns a valid sample
-/// for the IntegerPolyhedron poly.
-/// If hasSample is false, check that findIntegerSample returns None.
+///
+///   If hasSample is true, check that findIntegerSample returns a valid sample
+///   for the IntegerPolyhedron poly. Also check that getIntegerLexMin finds a
+///   non-empty lexmin.
+///
+///   If hasSample is false, check that findIntegerSample returns None and
+///   getIntegerLexMin returns Empty.
 ///
 /// If fn is TestFunction::Empty, check that isIntegerEmpty returns the
 /// opposite of hasSample.
 static void checkSample(bool hasSample, const IntegerPolyhedron &poly,
                         TestFunction fn = TestFunction::Sample) {
   Optional<SmallVector<int64_t, 8>> maybeSample;
+  MaybeOptimum<SmallVector<int64_t, 8>> maybeLexMin;
   switch (fn) {
   case TestFunction::Sample:
     maybeSample = poly.findIntegerSample();
+    maybeLexMin = poly.getIntegerLexMin();
+
     if (!hasSample) {
       EXPECT_FALSE(maybeSample.hasValue());
       if (maybeSample.hasValue()) {
-        for (auto x : *maybeSample)
-          llvm::errs() << x << ' ';
-        llvm::errs() << '\n';
+        llvm::errs() << "findIntegerSample gave sample: ";
+        dump(*maybeSample);
+      }
+
+      EXPECT_TRUE(maybeLexMin.isEmpty());
+      if (maybeLexMin.isBounded()) {
+        llvm::errs() << "getIntegerLexMin gave sample: ";
+        dump(*maybeLexMin);
       }
     } else {
       ASSERT_TRUE(maybeSample.hasValue());
       EXPECT_TRUE(poly.containsPoint(*maybeSample));
+
+      ASSERT_FALSE(maybeLexMin.isEmpty());
+      if (maybeLexMin.isUnbounded())
+        EXPECT_TRUE(Simplex(poly).isUnbounded());
+      if (maybeLexMin.isBounded())
+        EXPECT_TRUE(poly.containsPoint(*maybeLexMin));
     }
     break;
   case TestFunction::Empty:
@@ -1138,6 +1163,34 @@ TEST(IntegerPolyhedronTest, getRationalLexMin) {
                          parsePoly("(x) : (2*x >= 0, -x - 1 >= 0)", &context));
 }
 
+/// Expect that the integer lexmin of `poly` is bounded and equal to `min`.
+void expectIntegerLexMin(const IntegerPolyhedron &poly, ArrayRef<int64_t> min) {
+  auto lexMin = poly.getIntegerLexMin();
+  ASSERT_TRUE(lexMin.isBounded());
+  EXPECT_EQ(ArrayRef<int64_t>(*lexMin), min);
+}
+
+/// Expect that `poly` has no bounded integer lexmin, i.e. getIntegerLexMin
+/// returns `kind`, which must be either Empty or Unbounded.
+void expectNoIntegerLexMin(OptimumKind kind, const IntegerPolyhedron &poly) {
+  ASSERT_NE(kind, OptimumKind::Bounded)
+      << "Use expectIntegerLexMin for bounded min";
+  EXPECT_EQ(poly.getIntegerLexMin().getKind(), kind);
+}
+
+TEST(IntegerPolyhedronTest, getIntegerLexMin) {
+  MLIRContext context;
+  expectIntegerLexMin(parsePoly("(x, y, z) : (2*x + 13 >= 0, 4*y - 3*x - 2  >= "
+                                "0, 11*z + 5*y - 3*x + 7 >= 0)",
+                                &context),
+                      {-6, -4, 0});
+  // Similar to above but no lower bound on z.
+  expectNoIntegerLexMin(OptimumKind::Unbounded,
+                        parsePoly("(x, y, z) : (2*x + 13 >= 0, 4*y - 3*x - 2  "
+                                  ">= 0, -11*z + 5*y - 3*x + 7 >= 0)",
+                                  &context));
+}
+
 static void
 expectComputedVolumeIsValidOverapprox(const IntegerPolyhedron &poly,
                                       Optional<uint64_t> trueVolume,


        


More information about the Mlir-commits mailing list