[Mlir-commits] [mlir] da92f92 - [MLIR][Presburger] IntegerPolyhedron: add support for symbolic integer lexmin

Arjun P llvmlistbot at llvm.org
Mon Apr 4 16:25:00 PDT 2022


Author: Arjun P
Date: 2022-04-05T00:24:57+01:00
New Revision: da92f92621e28a56fe8ad79d82eb60e436bf1d39

URL: https://github.com/llvm/llvm-project/commit/da92f92621e28a56fe8ad79d82eb60e436bf1d39
DIFF: https://github.com/llvm/llvm-project/commit/da92f92621e28a56fe8ad79d82eb60e436bf1d39.diff

LOG: [MLIR][Presburger] IntegerPolyhedron: add support for symbolic integer lexmin

Add support for computing the symbolic integer lexmin of a polyhedron.
This finds, for every assignment to the symbols, the lexicographically
minimum value attained by the dimensions. For example, the symbolic lexmin
of the set

`(x, y)[a, b, c] : (a <= x, b <= x, x <= c)`

can be written as

```
x = a if b <= a, a <= c
x = b if a <  b, b <= c
```

This also finds the set of assignments to the symbols that make the lexmin unbounded.

Reviewed By: Groverkss

Differential Revision: https://reviews.llvm.org/D122985

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
    mlir/include/mlir/Analysis/Presburger/Matrix.h
    mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
    mlir/include/mlir/Analysis/Presburger/Simplex.h
    mlir/lib/Analysis/Presburger/IntegerRelation.cpp
    mlir/lib/Analysis/Presburger/Matrix.cpp
    mlir/lib/Analysis/Presburger/PWMAFunction.cpp
    mlir/lib/Analysis/Presburger/Simplex.cpp
    mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index 6add6e8aa9b24..709e4f8438356 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -536,6 +536,7 @@ class IntegerRelation : public PresburgerSpace {
   Matrix inequalities;
 };
 
+struct SymbolicLexMin;
 /// An IntegerPolyhedron is a PresburgerSpace subject to affine
 /// constraints. Affine constraints can be inequalities or equalities in the
 /// form:
@@ -593,6 +594,28 @@ class IntegerPolyhedron : public IntegerRelation {
   /// column position (i.e., not relative to the kind of identifier) of the
   /// first added identifier.
   unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1) override;
+
+  /// Compute the symbolic integer lexmin of the polyhedron.
+  /// This finds, for every assignment to the symbols, the lexicographically
+  /// minimum value attained by the dimensions. For example, the symbolic lexmin
+  /// of the set
+  ///
+  /// (x, y)[a, b, c] : (a <= x, b <= x, x <= c)
+  ///
+  /// can be written as
+  ///
+  /// x = a if b <= a, a <= c
+  /// x = b if a <  b, b <= c
+  ///
+  /// This function is stored as the `lexmin` member of the result.
+  /// Some assignments to the symbols might make the set empty.
+  /// Such points are not part of the function's domain.
+  /// In the above example, this happens when max(a, b) > c.
+  ///
+  /// For some values of the symbols, the lexmin may be unbounded.
+  /// `SymbolicLexMin` stores these parts of the symbolic domain in a separate
+  /// `PresburgerSet`, `unboundedDomain`.
+  SymbolicLexMin findSymbolicIntegerLexMin() const;
 };
 
 } // namespace presburger

diff  --git a/mlir/include/mlir/Analysis/Presburger/Matrix.h b/mlir/include/mlir/Analysis/Presburger/Matrix.h
index 940b88d8148f4..e2ad543070a4b 100644
--- a/mlir/include/mlir/Analysis/Presburger/Matrix.h
+++ b/mlir/include/mlir/Analysis/Presburger/Matrix.h
@@ -151,6 +151,9 @@ class Matrix {
 
   /// Add an extra row at the bottom of the matrix and return its position.
   unsigned appendExtraRow();
+  /// Same as above, but copy the given elements into the row. The length of
+  /// `elems` must be equal to the number of columns.
+  unsigned appendExtraRow(ArrayRef<int64_t> elems);
 
   /// Print the matrix.
   void print(raw_ostream &os) const;

diff  --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
index ce0d77da9bc2c..f4bffe5b4e7a4 100644
--- a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
+++ b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
@@ -106,6 +106,11 @@ class MultiAffineFunction : protected IntegerPolyhedron {
   /// outside the domain, an empty optional is returned.
   Optional<SmallVector<int64_t, 8>> valueAt(ArrayRef<int64_t> point) const;
 
+  /// Truncate the output dimensions to the first `count` dimensions.
+  ///
+  /// TODO: refactor so that this can be accomplished through removeIdRange.
+  void truncateOutput(unsigned count);
+
   void print(raw_ostream &os) const;
   void dump() const;
 
@@ -165,6 +170,11 @@ class PWMAFunction : public PresburgerSpace {
   /// value at every point in the domain.
   bool isEqual(const PWMAFunction &other) const;
 
+  /// Truncate the output dimensions to the first `count` dimensions.
+  ///
+  /// TODO: refactor so that this can be accomplished through removeIdRange.
+  void truncateOutput(unsigned count);
+
   void print(raw_ostream &os) const;
   void dump() const;
 

diff  --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h
index 66d408dbf8b69..67a4b5f68e202 100644
--- a/mlir/include/mlir/Analysis/Presburger/Simplex.h
+++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h
@@ -18,6 +18,7 @@
 #include "mlir/Analysis/Presburger/Fraction.h"
 #include "mlir/Analysis/Presburger/IntegerRelation.h"
 #include "mlir/Analysis/Presburger/Matrix.h"
+#include "mlir/Analysis/Presburger/PWMAFunction.h"
 #include "mlir/Analysis/Presburger/Utils.h"
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/ArrayRef.h"
@@ -41,8 +42,9 @@ class GBRSimplex;
 /// these constraints that are redundant, i.e. a subset of constraints that
 /// doesn't constrain the affine set further after adding the non-redundant
 /// constraints. The LexSimplex class provides support for computing the
-/// lexicographical minimum of an IntegerRelation. Both these classes can be
-/// constructed from an IntegerRelation, and both inherit common
+/// lexicographic minimum of an IntegerRelation. The SymbolicLexSimplex class
+/// provides support for computing symbolic lexicographic minima. All of these
+/// classes can be constructed from an IntegerRelation, and all inherit common
 /// functionality from SimplexBase.
 ///
 /// The implementations of the Simplex and SimplexBase classes, other than the
@@ -72,19 +74,22 @@ class GBRSimplex;
 /// respectively. As described above, the first column is the common
 /// denominator. The second column represents the constant term, explained in
 /// more detail below. These two are _fixed columns_; they always retain their
-/// position as the first and second columns. Additionally, LexSimplex stores
-/// a so-call big M parameter (explained below) in the third column, so
-/// LexSimplex has three fixed columns.
+/// position as the first and second columns. Additionally, LexSimplexBase
+/// stores a so-called big M parameter (explained below) in the third column, so
+/// LexSimplexBase has three fixed columns. Finally, SymbolicLexSimplex has
+/// `nSymbol` variables designated as symbols. These occupy the next `nSymbol`
+/// columns, viz. the columns [3, 3 + nSymbol). For more information on symbols,
+/// see LexSimplexBase and SymbolicLexSimplex.
 ///
-/// LexSimplex does not directly support variables which can be negative, so we
-/// introduce the so-called big M parameter, an artificial variable that is
+/// LexSimplexBase does not directly support variables which can be negative, so
+/// we introduce the so-called big M parameter, an artificial variable that is
 /// considered to have an arbitrarily large value. We then transform the
 /// variables, say x, y, z, ... to M, M + x, M + y, M + z. Since M has been
 /// added to these variables, they are now known to have non-negative values.
-/// For more details, see the documentation for LexSimplex. The big M parameter
-/// is not considered a real unknown and is not stored in the `var` data
-/// structure; rather the tableau just has an extra fixed column for it just
-/// like the constant term.
+/// For more details, see the documentation for LexSimplexBase. The big M
+/// parameter is not considered a real unknown and is not stored in the `var`
+/// data structure; rather the tableau just has an extra fixed column for it
+/// just like the constant term.
 ///
 /// The vectors var and con store information about the variables and
 /// constraints respectively, namely, whether they are in row or column
@@ -146,8 +151,8 @@ class GBRSimplex;
 /// operation from the end until we reach the snapshot's location. SimplexBase
 /// also supports taking a snapshot including the exact set of basis unknowns;
 /// if this functionality is used, then on rolling back the exact basis will
-/// also be restored. This is used by LexSimplex because its algorithm, unlike
-/// Simplex, is sensitive to the exact basis used at a point.
+/// also be restored. This is used by LexSimplexBase because the lex algorithm,
+/// unlike `Simplex`, is sensitive to the exact basis used at a point.
 class SimplexBase {
 public:
   SimplexBase() = delete;
@@ -211,7 +216,8 @@ class SimplexBase {
   /// constant term, whereas LexSimplex has an extra fixed column for the
   /// so-called big M parameter. For more information see the documentation for
   /// LexSimplex.
-  SimplexBase(unsigned nVar, bool mustUseBigM);
+  SimplexBase(unsigned nVar, bool mustUseBigM, unsigned symbolOffset,
+              unsigned nSymbol);
 
   enum class Orientation { Row, Column };
 
@@ -223,11 +229,14 @@ class SimplexBase {
   /// always be non-negative and if it cannot be made non-negative without
   /// violating other constraints, the tableau is empty.
   struct Unknown {
-    Unknown(Orientation oOrientation, bool oRestricted, unsigned oPos)
-        : pos(oPos), orientation(oOrientation), restricted(oRestricted) {}
+    Unknown(Orientation oOrientation, bool oRestricted, unsigned oPos,
+            bool oIsSymbol = false)
+        : pos(oPos), orientation(oOrientation), restricted(oRestricted),
+          isSymbol(oIsSymbol) {}
     unsigned pos;
     Orientation orientation;
     bool restricted : 1;
+    bool isSymbol : 1;
 
     void print(raw_ostream &os) const {
       os << (orientation == Orientation::Row ? "r" : "c");
@@ -326,6 +335,10 @@ class SimplexBase {
   /// nRedundant rows.
   unsigned nRedundant;
 
+  /// The number of symbols. This must be consistent with the number of
+  /// Unknowns in `var` below that have `isSymbol` set to true.
+  unsigned nSymbol;
+
   /// The matrix representing the tableau.
   Matrix tableau;
 
@@ -363,62 +376,45 @@ class SimplexBase {
 /// introduce an artifical variable M that is considered to have a value of
 /// +infinity and instead of the variables x, y, z, we internally use variables
 /// M + x, M + y, M + z, which are now guaranteed to be non-negative. See the
-/// documentation for Simplex for more details. The whole algorithm is performed
-/// without having to fix a "big enough" value of the big M parameter; it is
-/// just considered to be infinite throughout and it never appears in the final
-/// outputs. We will deal with sample values throughout that may in general be
-/// some linear expression involving M like pM + q or aM + b. We can compare
-/// these with each other. They have a total order:
-/// aM + b < pM + q iff a < p or (a == p and b < q).
+/// documentation for SimplexBase for more details. M is also considered to be
+/// an integer that is divisible by everything.
+///
+/// The whole algorithm is performed with M treated as a symbol;
+/// it is just considered to be infinite throughout and it never appears in the
+/// final outputs. We will deal with sample values throughout that may in
+/// general be some affine expression involving M, like pM + q or aM + b. We can
+/// compare these with each other. They have a total order:
+///
+/// aM + b < pM + q iff a < p or (a == p and b < q).
 /// In particular, aM + b < 0 iff a < 0 or (a == 0 and b < 0).
 ///
+/// When performing symbolic optimization, sample values will be affine
+/// expressions in M and the symbols. For example, we could have sample values
+/// aM + bS + c and pM + qS + r, where S is a symbol. Now we have
+/// aM + bS + c < pM + qS + r iff (a < p) or (a == p and bS + c < qS + r).
+/// bS + c < qS + r can be always true, always false, or neither,
+/// depending on the set of values S can take. The symbols are always stored
+/// in columns [3, 3 + nSymbol). For more details, see the
+/// documentation for SymbolicLexSimplex.
+///
 /// Initially all the constraints to be added are added as rows, with no attempt
 /// to keep the tableau consistent. Pivots are only performed when some query
 /// is made, such as a call to getRationalLexMin. Care is taken to always
 /// maintain a lexicopositive basis transform, explained below.
 ///
-/// Let the variables be x = (x_1, ... x_n). Let the basis unknowns at a
-/// particular point be  y = (y_1, ... y_n). We know that x = A*y + b for some
-/// n x n matrix A and n x 1 column vector b. We want every column in A to be
-/// lexicopositive, i.e., have at least one non-zero element, with the first
-/// such element being positive. This property is preserved throughout the
-/// operation of LexSimplex. Note that on construction, the basis transform A is
-/// the indentity matrix and so every column is lexicopositive. Note that for
-/// LexSimplex, for the tableau to be consistent we must have non-negative
-/// sample values not only for the constraints but also for the variables.
-/// So if the tableau is consistent then x >= 0 and y >= 0, by which we mean
-/// every element in these vectors is non-negative. (note that this is a
-/// different concept from lexicopositivity!)
-///
-/// When we arrive at a basis such the basis transform is lexicopositive and the
-/// tableau is consistent, the sample point is the lexiographically minimum
-/// point in the polytope. We will show that A*y is zero or lexicopositive when
-/// y >= 0. Adding a lexicopositive vector to b will make it lexicographically
-/// bigger, so A*y + b is lexicographically bigger than b for any y >= 0 except
-/// y = 0. This shows that no point lexicographically smaller than x = b can be
-/// obtained. Since we already know that x = b is valid point in the space, this
-/// shows that x = b is the lexicographic minimum.
-///
-/// Proof that A*y is lexicopositive or zero when y > 0. Recall that every
-/// column of A is lexicopositive. Begin by considering A_1, the first row of A.
-/// If this row is all zeros, then (A*y)_1 = (A_1)*y = 0; proceed to the next
-/// row. If we run out of rows, A*y is zero and we are done; otherwise, we
-/// encounter some row A_i that has a non-zero element. Every column is
-/// lexicopositive and so has some positive element before any negative elements
-/// occur, so the element in this row for any column, if non-zero, must be
-/// positive. Consider (A*y)_i = (A_i)*y. All the elements in both vectors are
-/// non-negative, so if this is non-zero then it must be positive. Then the
-/// first non-zero element of A*y is positive so A*y is lexicopositive.
-///
-/// Otherwise, if (A_i)*y is zero, then for every column j that had a non-zero
-/// element in A_i, y_j is zero. Thus these columns have no contribution to A*y
-/// and we can completely ignore these columns of A. We now continue downwards,
-/// looking for rows of A that have a non-zero element other than in the ignored
-/// columns. If we find one, say A_k, once again these elements must be positive
-/// since they are the first non-zero element in each of these columns, so if
-/// (A_k)*y is not zero then we have that A*y is lexicopositive and if not we
-/// ignore more columns; eventually if all these dot products become zero then
-/// A*y is zero and we are done.
+/// Let the variables be x = (x_1, ... x_n).
+/// Let the symbols be   s = (s_1, ... s_m). Let the basis unknowns at a
+/// particular point be  y = (y_1, ... y_n). We know that x = A*y + T*s + b for
+/// some n x n matrix A, n x m matrix T, and n x 1 column vector b. We want
+/// every column in A to be lexicopositive, i.e., have at least one non-zero
+/// element, with the first such element being positive. This property is
+/// preserved throughout the operation of LexSimplexBase. Note that on
+/// construction, the basis transform A is the identity matrix and so every
+/// column is lexicopositive. Note that for LexSimplexBase, for the tableau to
+/// be consistent we must have non-negative sample values not only for the
+/// constraints but also for the variables. So if the tableau is consistent then
+/// x >= 0 and y >= 0, by which we mean every element in these vectors is
+/// non-negative. (Note that this is a different concept from
+/// lexicopositivity!)
 class LexSimplexBase : public SimplexBase {
 public:
   ~LexSimplexBase() override = default;
@@ -435,25 +431,37 @@ class LexSimplexBase : public SimplexBase {
   unsigned getSnapshot() { return SimplexBase::getSnapshotBasis(); }
 
 protected:
-  LexSimplexBase(unsigned nVar) : SimplexBase(nVar, /*mustUseBigM=*/true) {}
+  LexSimplexBase(unsigned nVar, unsigned symbolOffset, unsigned nSymbol)
+      : SimplexBase(nVar, /*mustUseBigM=*/true, symbolOffset, nSymbol) {}
   explicit LexSimplexBase(const IntegerRelation &constraints)
-      : LexSimplexBase(constraints.getNumIds()) {
+      : LexSimplexBase(constraints.getNumIds(),
+                       constraints.getIdKindOffset(IdKind::Symbol),
+                       constraints.getNumSymbolIds()) {
     intersectIntegerRelation(constraints);
   }
 
+  /// Add a new symbolic variable to the end of the list of variables.
+  void appendSymbol();
+
   /// Try to move the specified row to column orientation while preserving the
-  /// lexicopositivity of the basis transform. If this is not possible, return
-  /// failure. This only occurs when the constraints have no solution; the
-  /// tableau will be marked empty in such a case.
+  /// lexicopositivity of the basis transform. The row must have a negative
+  /// sample value. If this is not possible, return failure. This only occurs
+  /// when the constraints have no solution; the tableau will be marked empty in
+  /// such a case.
   LogicalResult moveRowUnknownToColumn(unsigned row);
 
-  /// Given a row that has a non-integer sample value, add an inequality such
-  /// that this fractional sample value is cut away from the polytope. The added
-  /// inequality will be such that no integer points are removed.
+  /// Given a row that has a non-integer sample value, add an inequality to cut
+  /// away this fractional sample value from the polytope without removing any
+  /// integer points. The integer lexmin, if one existed, remains the same on
+  /// return.
   ///
-  /// Returns whether the cut constraint could be enforced, i.e. failure if the
-  /// cut made the polytope empty, and success if it didn't. Failure status
-  /// indicates that the polytope didn't have any integer points.
+  /// This assumes that the symbolic part of the sample is integral,
+  /// i.e., if the symbolic sample is (c + aM + b_1*s_1 + ... + b_n*s_n)/d,
+  /// where s_1, ... s_n are symbols, this assumes that
+  /// (b_1*s_1 + ... + b_n*s_n)/d is integral.
+  ///
+  /// Return failure if the tableau became empty, and success if it didn't.
+  /// Failure status indicates that the polytope was integer empty.
   LogicalResult addCut(unsigned row);
 
   /// Undo the addition of the last constraint. This is only called while
@@ -461,14 +469,19 @@ class LexSimplexBase : public SimplexBase {
   void undoLastConstraint() final;
 
   /// Given two potential pivot columns for a row, return the one that results
-  /// in the lexicographically smallest sample vector.
+  /// in the lexicographically smallest sample vector. The row's sample value
+  /// must be negative. If symbols are involved, the sample value must be
+  /// negative for all possible assignments to the symbols.
   unsigned getLexMinPivotColumn(unsigned row, unsigned colA,
                                 unsigned colB) const;
 };
 
+/// A class for lexicographic optimization without any symbols. This also
+/// provides support for integer-exact redundancy and separateness checks.
 class LexSimplex : public LexSimplexBase {
 public:
-  explicit LexSimplex(unsigned nVar) : LexSimplexBase(nVar) {}
+  explicit LexSimplex(unsigned nVar)
+      : LexSimplexBase(nVar, /*symbolOffset=*/0, /*nSymbol=*/0) {}
   explicit LexSimplex(const IntegerRelation &constraints)
       : LexSimplexBase(constraints) {
     assert(constraints.getNumSymbolIds() == 0 &&
@@ -502,7 +515,7 @@ class LexSimplex : public LexSimplexBase {
   MaybeOptimum<SmallVector<Fraction, 8>> getRationalSample() const;
 
   /// Make the tableau configuration consistent.
-  void restoreRationalConsistency();
+  LogicalResult restoreRationalConsistency();
 
   /// Return whether the specified row is violated;
   bool rowIsViolated(unsigned row) const;
@@ -514,11 +527,122 @@ class LexSimplex : public LexSimplexBase {
   /// Get a row corresponding to a var that has a non-integral sample value, if
   /// one exists. Otherwise, return an empty optional.
   Optional<unsigned> maybeGetNonIntegralVarRow() const;
+};
 
-  /// Given two potential pivot columns for a row, return the one that results
-  /// in the lexicographically smallest sample vector.
-  unsigned getLexMinPivotColumn(unsigned row, unsigned colA,
-                                unsigned colB) const;
+/// Represents the result of a symbolic lexicographic minimization computation.
+struct SymbolicLexMin {
+  SymbolicLexMin(unsigned nSymbols, unsigned nNonSymbols)
+      : lexmin(PresburgerSpace::getSetSpace(nSymbols), nNonSymbols),
+        unboundedDomain(
+            PresburgerSet::getEmpty(PresburgerSpace::getSetSpace(nSymbols))) {}
+
+  /// This maps assignments of symbols to the corresponding lexmin.
+  /// Takes no value when no integer sample exists for the assignment or if the
+  /// lexmin is unbounded.
+  PWMAFunction lexmin;
+  /// Contains all assignments to the symbols that made the lexmin unbounded.
+  /// Note that the symbols of the input set to the symbolic lexmin are dims
+  /// of this PresburgerSet.
+  PresburgerSet unboundedDomain;
+};
+
+/// A class to perform symbolic lexicographic optimization,
+/// i.e., to find, for every assignment to the symbols in the specified
+/// `symbolDomain`, the lexicographically minimum integer value attained
+/// by the non-symbol variables.
+///
+/// The input is a set parametrized by some symbols, i.e., the constant terms
+/// of the constraints in the set are affine expressions in the symbols, and
+/// every assignment to the symbols defines a non-symbolic set.
+///
+/// Accordingly, the sample values of the rows in our tableau will be affine
+/// expressions in the symbols, and every assignment to the symbols will define
+/// a non-symbolic LexSimplex. We then run the algorithm of
+/// LexSimplex::findIntegerLexMin simultaneously for every value of the symbols
+/// in the domain.
+///
+/// Often, the pivot to be performed is the same for all values of the symbols,
+/// in which case we just do it. For example, if the symbolic sample of a row is
+/// negative for all values in the symbol domain, the row needs to be pivoted
+/// irrespective of the precise value of the symbols. To answer queries like
+/// "Is this symbolic sample always negative in the symbol domain?", we maintain
+/// a `LexSimplex domainSimplex` corresponding to the symbol domain.
+///
+/// In other cases, it may be that the symbolic sample is violated at some
+/// values in the symbol domain and not violated at others. In this case,
+/// the pivot to be performed does depend on the value of the symbols. We
+/// handle this by splitting the symbol domain. We run the algorithm for the
+/// case where the row isn't violated, and then come back and run the case
+/// where it is.
+class SymbolicLexSimplex : public LexSimplexBase {
+public:
+  /// `constraints` is the set for which the symbolic lexmin will be computed.
+  /// `symbolDomain` is the set of values of the symbols for which the lexmin
+  /// will be computed. `symbolDomain` should have a dim id for every symbol in
+  /// `constraints`, and no other ids.
+  SymbolicLexSimplex(const IntegerPolyhedron &constraints,
+                     const IntegerPolyhedron &symbolDomain)
+      : LexSimplexBase(constraints), domainPoly(symbolDomain),
+        domainSimplex(symbolDomain) {
+    assert(domainPoly.getNumIds() == constraints.getNumSymbolIds());
+    assert(domainPoly.getNumDimIds() == constraints.getNumSymbolIds());
+  }
+
+  /// The lexmin will be stored as a function `lexmin` from symbols to
+  /// non-symbols in the result.
+  ///
+  /// For some values of the symbols, the lexmin may be unbounded.
+  /// These parts of the symbol domain will be stored in `unboundedDomain`.
+  SymbolicLexMin computeSymbolicIntegerLexMin();
+
+private:
+  /// Perform all pivots that do not require branching.
+  ///
+  /// Return failure if the tableau became empty, indicating that the polytope
+  /// is always integer empty in the current symbol domain.
+  /// Return success otherwise.
+  LogicalResult doNonBranchingPivots();
+
+  /// Get a row that is always violated in the current domain, if one exists.
+  Optional<unsigned> maybeGetAlwaysViolatedRow();
+
+  /// Get a row corresponding to a variable with non-integral sample value, if
+  /// one exists.
+  Optional<unsigned> maybeGetNonIntegralVarRow();
+
+  /// Given a row that has a non-integer sample value, cut away this fractional
+  /// sample value without removing any integer points, i.e., the integer
+  /// lexmin, if it exists, remains the same after a call to this function. This
+  /// may add constraints or local variables to the tableau, as well as to the
+  /// domain.
+  ///
+  /// Returns whether the cut constraint could be enforced, i.e. failure if the
+  /// cut made the polytope empty, and success if it didn't. Failure status
+  /// indicates that the polytope is always integer empty in the symbol domain
+  /// at the time of the call. (This function may modify the symbol domain, but
+  /// failure status indicates that the polytope was empty for all symbol values
+  /// in the initial domain.)
+  LogicalResult addSymbolicCut(unsigned row);
+
+  /// Get the numerator of the symbolic sample of the specific row.
+  /// This is an affine expression in the symbols with integer coefficients.
+  /// The last element is the constant term. This ignores the big M coefficient.
+  SmallVector<int64_t, 8> getSymbolicSampleNumerator(unsigned row) const;
+
+  /// Return whether all the coefficients of the symbolic sample are integers.
+  ///
+  /// This does not consult the domain to check if the specified expression
+  /// is always integral despite coefficients being fractional.
+  bool isSymbolicSampleIntegral(unsigned row) const;
+
+  /// Record a lexmin. The tableau must be consistent with all variables
+  /// having symbolic samples with integer coefficients.
+  void recordOutput(SymbolicLexMin &result) const;
+
+  /// The symbol domain.
+  IntegerPolyhedron domainPoly;
+  /// Simplex corresponding to the symbol domain.
+  LexSimplex domainSimplex;
 };
 
 /// The Simplex class uses the Normal pivot rule and supports integer emptiness
@@ -540,7 +664,9 @@ class Simplex : public SimplexBase {
   enum class Direction { Up, Down };
 
   Simplex() = delete;
-  explicit Simplex(unsigned nVar) : SimplexBase(nVar, /*mustUseBigM=*/false) {}
+  explicit Simplex(unsigned nVar)
+      : SimplexBase(nVar, /*mustUseBigM=*/false, /*symbolOffset=*/0,
+                    /*nSymbol=*/0) {}
   explicit Simplex(const IntegerRelation &constraints)
       : Simplex(constraints.getNumIds()) {
     intersectIntegerRelation(constraints);

diff  --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index bfa9a6539077d..5e527b5467f54 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -14,6 +14,7 @@
 
 #include "mlir/Analysis/Presburger/IntegerRelation.h"
 #include "mlir/Analysis/Presburger/LinearTransform.h"
+#include "mlir/Analysis/Presburger/PWMAFunction.h"
 #include "mlir/Analysis/Presburger/PresburgerRelation.h"
 #include "mlir/Analysis/Presburger/Simplex.h"
 #include "mlir/Analysis/Presburger/Utils.h"
@@ -145,6 +146,21 @@ void IntegerRelation::truncate(const CountsSnapshot &counts) {
   removeEqualityRange(counts.getNumEqs(), getNumEqualities());
 }
 
+SymbolicLexMin IntegerPolyhedron::findSymbolicIntegerLexMin() const {
+  // Compute the symbolic lexmin of the dims and locals, with the symbols being
+  // the actual symbols of this set.
+  SymbolicLexMin result =
+      SymbolicLexSimplex(
+          *this, PresburgerSpace::getSetSpace(/*numDims=*/getNumSymbolIds()))
+          .computeSymbolicIntegerLexMin();
+
+  // We want to return only the lexmin over the dims, so strip the locals from
+  // the computed lexmin.
+  result.lexmin.truncateOutput(result.lexmin.getNumOutputs() -
+                               getNumLocalIds());
+  return result;
+}
+
 unsigned IntegerRelation::insertId(IdKind kind, unsigned pos, unsigned num) {
   assert(pos <= getNumIdKind(kind));
 

diff  --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp
index 219d490e7368a..680e4509b7cc8 100644
--- a/mlir/lib/Analysis/Presburger/Matrix.cpp
+++ b/mlir/lib/Analysis/Presburger/Matrix.cpp
@@ -66,6 +66,14 @@ unsigned Matrix::appendExtraRow() {
   return nRows - 1;
 }
 
+unsigned Matrix::appendExtraRow(ArrayRef<int64_t> elems) {
+  assert(elems.size() == nColumns && "elems must match row length!");
+  unsigned row = appendExtraRow();
+  for (unsigned col = 0; col < nColumns; ++col)
+    at(row, col) = elems[col];
+  return row;
+}
+
 void Matrix::resizeHorizontally(unsigned newNColumns) {
   if (newNColumns < nColumns)
     removeColumns(newNColumns, nColumns - newNColumns);

diff  --git a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
index b995bc00a19c8..711e99aab35b4 100644
--- a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
+++ b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
@@ -114,6 +114,18 @@ void MultiAffineFunction::eliminateRedundantLocalId(unsigned posA,
   IntegerPolyhedron::eliminateRedundantLocalId(posA, posB);
 }
 
+void MultiAffineFunction::truncateOutput(unsigned count) {
+  assert(count <= output.getNumRows());
+  output.resizeVertically(count);
+}
+
+void PWMAFunction::truncateOutput(unsigned count) {
+  assert(count <= numOutputs);
+  for (MultiAffineFunction &piece : pieces)
+    piece.truncateOutput(count);
+  numOutputs = count;
+}
+
 bool MultiAffineFunction::isEqualWhereDomainsOverlap(
     MultiAffineFunction other) const {
   if (!isSpaceCompatible(other))

diff  --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp
index 57e8f485742d2..f3bf42f40b177 100644
--- a/mlir/lib/Analysis/Presburger/Simplex.cpp
+++ b/mlir/lib/Analysis/Presburger/Simplex.cpp
@@ -18,15 +18,24 @@ using Direction = Simplex::Direction;
 
 const int nullIndex = std::numeric_limits<int>::max();
 
-SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM)
+/// Construct a tableau with no rows and `nVar` variables, all initially in
+/// column orientation. The `nSymbol` variables starting at index
+/// `symbolOffset` are marked as symbols and moved to the columns directly
+/// after the fixed columns. Requires `symbolOffset + nSymbol <= nVar`.
+SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM, unsigned symbolOffset,
+                         unsigned nSymbol)
     : usingBigM(mustUseBigM), nRow(0), nCol(getNumFixedCols() + nVar),
-      nRedundant(0), tableau(0, nCol), empty(false) {
+      nRedundant(0), nSymbol(nSymbol), tableau(0, nCol), empty(false) {
+  assert(symbolOffset + nSymbol <= nVar);
+
   colUnknown.insert(colUnknown.begin(), getNumFixedCols(), nullIndex);
   for (unsigned i = 0; i < nVar; ++i) {
     var.emplace_back(Orientation::Column, /*restricted=*/false,
                      /*pos=*/getNumFixedCols() + i);
     colUnknown.push_back(i);
   }
+
+  // Move the symbols to the columns directly after the fixed columns, i.e.,
+  // [getNumFixedCols(), getNumFixedCols() + nSymbol), keeping their order.
+  for (unsigned i = 0; i < nSymbol; ++i) {
+    var[symbolOffset + i].isSymbol = true;
+    swapColumns(var[symbolOffset + i].pos, getNumFixedCols() + i);
+  }
 }
 
 const Simplex::Unknown &SimplexBase::unknownFromIndex(int index) const {
@@ -96,9 +105,13 @@ unsigned SimplexBase::addRow(ArrayRef<int64_t> coeffs, bool makeRestricted) {
     // where M is the big M parameter. As such, when the user tries to add
     // a row ax + by + cz + d, we express it in terms of our internal variables
     // as -(a + b + c)M + a(M + x) + b(M + y) + c(M + z) + d.
+    //
+    // Symbols don't use the big M parameter since they do not get lex
+    // optimized.
     int64_t bigMCoeff = 0;
     for (unsigned i = 0; i < coeffs.size() - 1; ++i)
-      bigMCoeff -= coeffs[i];
+      if (!var[i].isSymbol)
+        bigMCoeff -= coeffs[i];
     // The coefficient to the big M parameter is stored in column 2.
     tableau(nRow - 1, 2) = bigMCoeff;
   }
@@ -164,19 +177,97 @@ Direction flippedDirection(Direction direction) {
 }
 } // namespace
 
+/// We simply make the tableau consistent while maintaining a lexicopositive
+/// basis transform, and then return the sample value. If the tableau becomes
+/// empty, we return OptimumKind::Empty.
+///
+/// Let the variables be x = (x_1, ... x_n).
+/// Let the basis unknowns be y = (y_1, ... y_n).
+/// We have that x = A*y + b for some n x n matrix A and n x 1 column vector b.
+///
+/// As we will show below, A*y is either zero or lexicopositive.
+/// Adding a lexicopositive vector to b will make it lexicographically
+/// greater, so A*y + b is always equal to or lexicographically greater than b.
+/// Thus, since we can attain x = b, that is the lexicographic minimum.
+///
+/// We have that every column in A is lexicopositive, i.e., has at least
+/// one non-zero element, with the first such element being positive. Since for
+/// the tableau to be consistent we must have non-negative sample values not
+/// only for the constraints but also for the variables, we also have x >= 0 and
+/// y >= 0, by which we mean every element in these vectors is non-negative.
+///
+/// Proof that if every column in A is lexicopositive, and y >= 0, then
+/// A*y is zero or lexicopositive. Begin by considering A_1, the first row of A.
+/// If this row is all zeros, then (A*y)_1 = (A_1)*y = 0; proceed to the next
+/// row. If we run out of rows, A*y is zero and we are done; otherwise, we
+/// encounter some row A_i that has a non-zero element. Every column is
+/// lexicopositive and so has some positive element before any negative elements
+/// occur, so the element in this row for any column, if non-zero, must be
+/// positive. Consider (A*y)_i = (A_i)*y. All the elements in both vectors are
+/// non-negative, so if this is non-zero then it must be positive. Then the
+/// first non-zero element of A*y is positive so A*y is lexicopositive.
+///
+/// Otherwise, if (A_i)*y is zero, then for every column j that had a non-zero
+/// element in A_i, y_j is zero. Thus these columns have no contribution to A*y
+/// and we can completely ignore these columns of A. We now continue downwards,
+/// looking for rows of A that have a non-zero element other than in the ignored
+/// columns. If we find one, say A_k, once again these elements must be positive
+/// since they are the first non-zero element in each of these columns, so if
+/// (A_k)*y is not zero then we have that A*y is lexicopositive and if not we
+/// add these to the set of ignored columns and continue to the next row. If we
+/// run out of rows, then A*y is zero and we are done.
 MaybeOptimum<SmallVector<Fraction, 8>> LexSimplex::findRationalLexMin() {
-  restoreRationalConsistency();
+  if (restoreRationalConsistency().failed())
+    return OptimumKind::Empty;
   return getRationalSample();
 }
 
+/// Given a row that has a non-integer sample value, add an inequality such
+/// that this fractional sample value is cut away from the polytope. The added
+/// inequality will be such that no integer points are removed, i.e., the
+/// integer lexmin, if it exists, is the same with and without this constraint.
+///
+/// Let the row be
+/// (c + coeffM*M + a_1*s_1 + ... + a_m*s_m + b_1*y_1 + ... + b_n*y_n)/d,
+/// where s_1, ... s_m are the symbols and
+///       y_1, ... y_n are the other basis unknowns.
+///
+/// For this to be an integer, we want
+/// coeffM*M + a_1*s_1 + ... + a_m*s_m + b_1*y_1 + ... + b_n*y_n = -c (mod d)
+/// Note that this constraint must always hold, independent of the basis,
+/// because the row unknown's value always equals this expression, even if *we*
+/// later compute the sample value from a different expression based on a
+/// different basis.
+///
+/// Let us assume that M has a factor of d in it. Imposing this constraint on M
+/// does not in any way hinder us from finding a value of M that is big enough.
+/// Moreover, this function is only called when the symbolic part of the sample,
+/// a_1*s_1 + ... + a_m*s_m, is known to be an integer.
+///
+/// Also, we can safely reduce the coefficients modulo d, so we have:
+///
+/// (b_1%d)y_1 + ... + (b_n%d)y_n = (-c%d) + k*d for some integer `k`
+///
+/// Note that all coefficient modulos here are non-negative. Also, all the
+/// unknowns are non-negative here as both constraints and variables are
+/// non-negative in LexSimplexBase (we used the big M trick to make the
+/// variables non-negative). Therefore, the LHS here is non-negative.
+/// Since 0 <= (-c%d) < d, k is the quotient of dividing the LHS by d and
+/// is therefore non-negative as well.
+///
+/// So we have
+/// ((b_1%d)y_1 + ... + (b_n%d)y_n - (-c%d))/d >= 0.
+///
+/// The symbol columns of the cut are left zero, since the symbolic part of the
+/// sample is integral. The big M column is zero since M is divisible by d.
+///
+/// The constraint is violated when added (it would be useless otherwise)
+/// so we immediately try to move it to a column.
 LogicalResult LexSimplexBase::addCut(unsigned row) {
-  int64_t denom = tableau(row, 0);
+  int64_t d = tableau(row, 0);
   addZeroRow(/*makeRestricted=*/true);
-  tableau(nRow - 1, 0) = denom;
-  tableau(nRow - 1, 1) = -mod(-tableau(row, 1), denom);
-  tableau(nRow - 1, 2) = 0; // M has all factors in it.
-  for (unsigned col = 3; col < nCol; ++col)
-    tableau(nRow - 1, col) = mod(tableau(row, col), denom);
+  tableau(nRow - 1, 0) = d;
+  tableau(nRow - 1, 1) = -mod(-tableau(row, 1), d); // -c%d.
+  tableau(nRow - 1, 2) = 0; // M is assumed divisible by d.
+  for (unsigned col = 3 + nSymbol; col < nCol; ++col)
+    tableau(nRow - 1, col) = mod(tableau(row, col), d); // b_i%d.
   return moveRowUnknownToColumn(nRow - 1);
 }
 
@@ -185,7 +276,7 @@ Optional<unsigned> LexSimplex::maybeGetNonIntegralVarRow() const {
     if (u.orientation == Orientation::Column)
       continue;
     // If the sample value is of the form (a/d)M + b/d, we need b to be
-    // divisible by d. We assume M is very large and contains all possible
+    // divisible by d. We assume M contains all possible
     // factors and is divisible by everything.
     unsigned row = u.pos;
     if (tableau(row, 1) % tableau(row, 0) != 0)
@@ -195,28 +286,34 @@ Optional<unsigned> LexSimplex::maybeGetNonIntegralVarRow() const {
 }
 
+/// Compute the integer lexmin with a cutting-plane approach:
+/// - restore rational consistency;
+/// - while some variable has a non-integral sample value, add a cut for it
+///   and restore consistency again.
+/// Returns OptimumKind::Empty if the tableau becomes empty at any point.
+/// Returns OptimumKind::Unbounded if the final sample is unbounded.
 MaybeOptimum<SmallVector<int64_t, 8>> LexSimplex::findIntegerLexMin() {
-  while (!empty) {
-    restoreRationalConsistency();
-    if (empty)
-      return OptimumKind::Empty;
-
-    if (Optional<unsigned> maybeRow = maybeGetNonIntegralVarRow()) {
-      // Failure occurs when the polytope is integer empty.
-      if (failed(addCut(*maybeRow)))
-        return OptimumKind::Empty;
-      continue;
-    }
+  // We first try to make the tableau consistent.
+  if (restoreRationalConsistency().failed())
+    return OptimumKind::Empty;
 
-    MaybeOptimum<SmallVector<Fraction, 8>> sample = getRationalSample();
-    assert(!sample.isEmpty() && "If we reached here the sample should exist!");
-    if (sample.isUnbounded())
-      return OptimumKind::Unbounded;
-    return llvm::to_vector<8>(
-        llvm::map_range(*sample, std::mem_fn(&Fraction::getAsInteger)));
+  // Then, if the sample value is integral, we are done.
+  while (Optional<unsigned> maybeRow = maybeGetNonIntegralVarRow()) {
+    // Otherwise, for the variable whose row has a non-integral sample value,
+    // we add a cut, a constraint that removes this rational point
+    // while preserving all integer points, thus keeping the lexmin the same.
+    // We then try again to make the tableau, with the new constraint,
+    // consistent. This continues until the tableau becomes empty, in which
+    // case there is no integer point, or until there are no variables with
+    // non-integral sample values.
+    //
+    // Failure indicates that the tableau became empty, which occurs when the
+    // polytope is integer empty.
+    if (addCut(*maybeRow).failed())
+      return OptimumKind::Empty;
+    if (restoreRationalConsistency().failed())
+      return OptimumKind::Empty;
   }
 
-  // Polytope is integer empty.
-  return OptimumKind::Empty;
+  MaybeOptimum<SmallVector<Fraction, 8>> sample = getRationalSample();
+  assert(!sample.isEmpty() && "If we reached here the sample should exist!");
+  if (sample.isUnbounded())
+    return OptimumKind::Unbounded;
+  return llvm::to_vector<8>(
+      llvm::map_range(*sample, std::mem_fn(&Fraction::getAsInteger)));
 }
 
 bool LexSimplex::isSeparateInequality(ArrayRef<int64_t> coeffs) {
@@ -228,6 +325,319 @@ bool LexSimplex::isSeparateInequality(ArrayRef<int64_t> coeffs) {
+/// `coeffs` is redundant iff its complement is a separate inequality, i.e., the
+/// complement can never be satisfied, so `coeffs` always holds.
 bool LexSimplex::isRedundantInequality(ArrayRef<int64_t> coeffs) {
   return isSeparateInequality(getComplementIneq(coeffs));
 }
+
+/// Return the numerator of the symbolic sample value of `row`. This is the
+/// symbol coefficients (columns [3, 3 + nSymbol)) followed by the constant
+/// term (column 1). The denominator (column 0) and the big M coefficient
+/// (column 2) are not included.
+SmallVector<int64_t, 8>
+SymbolicLexSimplex::getSymbolicSampleNumerator(unsigned row) const {
+  SmallVector<int64_t, 8> sample;
+  sample.reserve(nSymbol + 1);
+  for (unsigned col = 3; col < 3 + nSymbol; ++col)
+    sample.push_back(tableau(row, col));
+  sample.push_back(tableau(row, 1));
+  return sample;
+}
+
+/// Append a new variable and mark it as a symbol, keeping the symbol columns
+/// contiguous.
+/// NOTE(review): assumes appendVariable() places the new variable in the last
+/// column. The swap moves it to column 3 + nSymbol, and whatever column was
+/// there moves to the end.
+void LexSimplexBase::appendSymbol() {
+  appendVariable();
+  swapColumns(3 + nSymbol, nCol - 1);
+  var.back().isSymbol = true;
+  nSymbol++;
+}
+
+/// Return true if every element of `range` is divisible by `divisor`, which
+/// must be positive. An empty range is trivially divisible.
+static bool isRangeDivisibleBy(ArrayRef<int64_t> range, int64_t divisor) {
+  assert(divisor > 0 && "divisor must be positive!");
+  return llvm::all_of(range, [divisor](int64_t x) { return x % divisor == 0; });
+}
+
+/// Return true if the constant term and every symbol coefficient of `row` are
+/// divisible by the row's denominator. This ensures the symbolic sample is an
+/// integer for every integer assignment to the symbols. The big M coefficient
+/// and the non-symbol columns are not inspected.
+bool SymbolicLexSimplex::isSymbolicSampleIntegral(unsigned row) const {
+  int64_t denom = tableau(row, 0);
+  return tableau(row, 1) % denom == 0 &&
+         isRangeDivisibleBy(tableau.getRow(row).slice(3, nSymbol), denom);
+}
+
+/// This proceeds similarly to LexSimplexBase::addCut(). We are given a row that
+/// has a symbolic sample value with fractional coefficients.
+///
+/// Let the row be
+/// (c + coeffM*M + sum_i a_i*s_i + sum_j b_j*y_j)/d,
+/// where s_1, ... s_m are the symbols and
+///       y_1, ... y_n are the other basis unknowns.
+///
+/// As in LexSimplexBase::addCut, for this to be an integer, we want
+///
+/// coeffM*M + sum_j b_j*y_j = -c + sum_i (-a_i*s_i) (mod d)
+///
+/// This time, a_1*s_1 + ... + a_m*s_m may not be an integer. We find that
+///
+/// sum_j (b_j%d)y_j = ((-c%d) + sum_i (-a_i%d)s_i)%d + k*d for some integer k
+///
+/// where we take a modulo of the whole symbolic expression on the right to
+/// bring it into the range [0, d - 1]. Therefore, as in LexSimplexBase::addCut,
+/// k is the quotient on dividing the LHS by d, and since LHS >= 0, we have
+/// k >= 0 as well. We realize the modulo of the symbolic expression by adding a
+/// division variable
+///
+/// q = floor(((-c%d) + sum_i (-a_i%d)s_i)/d)
+///
+/// to the symbol domain, so the equality becomes
+///
+/// sum_j (b_j%d)y_j = (-c%d) + sum_i (-a_i%d)s_i - q*d + k*d for some integer k
+///
+/// So the cut is
+/// (sum_j (b_j%d)y_j - (-c%d) - sum_i (-a_i%d)s_i + q*d)/d >= 0
+/// This constraint is violated when added so we immediately try to move it to a
+/// column.
+LogicalResult SymbolicLexSimplex::addSymbolicCut(unsigned row) {
+  int64_t d = tableau(row, 0);
+
+  // Add the division variable `q` described above to the symbol domain.
+  // q = floor(((-c%d) + sum_i (-a_i%d)s_i)/d).
+  SmallVector<int64_t, 8> domainDivCoeffs;
+  domainDivCoeffs.reserve(nSymbol + 1);
+  for (unsigned col = 3; col < 3 + nSymbol; ++col)
+    domainDivCoeffs.push_back(mod(-tableau(row, col), d)); // (-a_i%d)s_i
+  domainDivCoeffs.push_back(mod(-tableau(row, 1), d));     // -c%d.
+
+  domainSimplex.addDivisionVariable(domainDivCoeffs, d);
+  domainPoly.addLocalFloorDiv(domainDivCoeffs, d);
+
+  // Update `this` to account for the additional symbol we just added. This
+  // increments nSymbol, so `q` now lives in column 3 + nSymbol - 1.
+  appendSymbol();
+
+  // Add the cut (sum_j (b_j%d)y_j - (-c%d) + sum_i -(-a_i%d)s_i + q*d)/d >= 0.
+  addZeroRow(/*makeRestricted=*/true);
+  tableau(nRow - 1, 0) = d;
+  tableau(nRow - 1, 2) = 0; // M is assumed divisible by d.
+
+  tableau(nRow - 1, 1) = -mod(-tableau(row, 1), d); // -(-c%d).
+  for (unsigned col = 3; col < 3 + nSymbol - 1; ++col)
+    tableau(nRow - 1, col) = -mod(-tableau(row, col), d); // -(-a_i%d)s_i.
+  tableau(nRow - 1, 3 + nSymbol - 1) = d;                 // q*d.
+
+  for (unsigned col = 3 + nSymbol; col < nCol; ++col)
+    tableau(nRow - 1, col) = mod(tableau(row, col), d); // (b_j%d)y_j.
+  return moveRowUnknownToColumn(nRow - 1);
+}
+
+/// Record the result for the current symbol domain `domainPoly` into `result`.
+/// If any non-symbol variable is unbounded, the whole domain is added to
+/// `result.unboundedDomain` and nothing else is recorded. A variable is
+/// unbounded when it is in column orientation, or when its row has a big M
+/// coefficient below 1.
+/// Otherwise, each non-symbol variable's symbolic sample, divided by its
+/// denominator (which must divide it exactly), becomes one output row. The
+/// resulting piece over `domainPoly` is added to `result.lexmin`.
+void SymbolicLexSimplex::recordOutput(SymbolicLexMin &result) const {
+  Matrix output(0, domainPoly.getNumIds() + 1);
+  output.reserveRows(result.lexmin.getNumOutputs());
+  for (const Unknown &u : var) {
+    if (u.isSymbol)
+      continue;
+
+    if (u.orientation == Orientation::Column) {
+      // M + u has a sample value of zero so u has a sample value of -M, i.e.,
+      // unbounded.
+      result.unboundedDomain.unionInPlace(domainPoly);
+      return;
+    }
+
+    int64_t denom = tableau(u.pos, 0);
+    if (tableau(u.pos, 2) < denom) {
+      // M + u has a sample value of fM + something, where f < 1, so
+      // u = (f - 1)M + something, which has a negative coefficient for M,
+      // and so is unbounded.
+      result.unboundedDomain.unionInPlace(domainPoly);
+      return;
+    }
+    assert(tableau(u.pos, 2) == denom &&
+           "Coefficient of M should not be greater than 1!");
+
+    SmallVector<int64_t, 8> sample = getSymbolicSampleNumerator(u.pos);
+    for (int64_t &elem : sample) {
+      assert(elem % denom == 0 && "coefficients must be integral!");
+      elem /= denom;
+    }
+    output.appendExtraRow(sample);
+  }
+  result.lexmin.addPiece(domainPoly, output);
+}
+
+/// Return a row whose sample value is negative for every assignment to the
+/// symbols in the current domain, if such a row exists. Rows with a negative
+/// big M coefficient are returned first. Otherwise, a row with a zero big M
+/// coefficient is returned if its symbolic sample numerator is separate from
+/// the domain, i.e., always negative there.
+Optional<unsigned> SymbolicLexSimplex::maybeGetAlwaysViolatedRow() {
+  // First look for rows that are clearly violated just from the big M
+  // coefficient, without needing to perform any simplex queries on the domain.
+  for (unsigned row = 0; row < nRow; ++row)
+    if (tableau(row, 2) < 0)
+      return row;
+
+  for (unsigned row = 0; row < nRow; ++row) {
+    if (tableau(row, 2) > 0)
+      continue;
+    if (domainSimplex.isSeparateInequality(getSymbolicSampleNumerator(row))) {
+      // Sample numerator always takes negative values in the symbol domain.
+      return row;
+    }
+  }
+  return {};
+}
+
+/// Return the row of some variable whose symbolic sample value is not known to
+/// be integral (see isSymbolicSampleIntegral), if any. Unknowns in column
+/// orientation are skipped. Symbols are expected never to be in a row.
+Optional<unsigned> SymbolicLexSimplex::maybeGetNonIntegralVarRow() {
+  for (const Unknown &u : var) {
+    if (u.orientation == Orientation::Column)
+      continue;
+    assert(!u.isSymbol && "Symbol should not be in row orientation!");
+    if (!isSymbolicSampleIntegral(u.pos))
+      return u.pos;
+  }
+  return {};
+}
+
+/// The non-branching pivots are just the ones moving the rows
+/// that are always violated in the symbol domain. These pivots do not require
+/// splitting the domain. Returns failure if some such row cannot be moved to
+/// a column, in which case there is no solution in the current domain.
+LogicalResult SymbolicLexSimplex::doNonBranchingPivots() {
+  while (Optional<unsigned> row = maybeGetAlwaysViolatedRow())
+    if (moveRowUnknownToColumn(*row).failed())
+      return failure();
+  return success();
+}
+
+/// Compute the symbolic integer lexmin of the non-symbol variables. Each part
+/// of the symbol domain gets one of two results:
+/// - an output giving the lexmin in terms of the symbols (and any division
+///   variables added by cuts), recorded in `result.lexmin`;
+/// - membership in `result.unboundedDomain`.
+/// The domain is split on rows whose sign depends on the symbols, cuts are
+/// added for non-integral samples, and each leaf is recorded via
+/// recordOutput().
+SymbolicLexMin SymbolicLexSimplex::computeSymbolicIntegerLexMin() {
+  SymbolicLexMin result(nSymbol, var.size() - nSymbol);
+
+  // The algorithm is more naturally expressed recursively, but we implement
+  // it iteratively here to avoid potential issues with stack overflows in the
+  // compiler. We explicitly maintain the stack frames in a vector.
+  //
+  // To "recurse", we store the current "stack frame", i.e., state variables
+  // that we will need when we "return", into `stack`, increment `level`, and
+  // `continue`. To "tail recurse", we just `continue`.
+  // To "return", we decrement `level` and `continue`.
+  //
+  // When there is no stack frame for the current `level`, this indicates that
+  // we have just "recursed" or "tail recursed". When there does exist one,
+  // this indicates that we have just "returned" from recursing. There is only
+  // one point at which non-tail calls occur so we always "return" there.
+  unsigned level = 1;
+  struct StackFrame {
+    int splitIndex;
+    unsigned snapshot;
+    unsigned domainSnapshot;
+    IntegerRelation::CountsSnapshot domainPolyCounts;
+  };
+  SmallVector<StackFrame, 8> stack;
+
+  while (level > 0) {
+    assert(level >= stack.size());
+    if (level > stack.size()) {
+      if (empty || domainSimplex.findIntegerLexMin().isEmpty()) {
+        // No integer points; return.
+        --level;
+        continue;
+      }
+
+      if (doNonBranchingPivots().failed()) {
+        // Could not find pivots for violated constraints; return.
+        --level;
+        continue;
+      }
+
+      unsigned splitRow;
+      SmallVector<int64_t, 8> symbolicSample;
+      for (splitRow = 0; splitRow < nRow; ++splitRow) {
+        if (tableau(splitRow, 2) > 0)
+          continue;
+        assert(tableau(splitRow, 2) == 0 &&
+               "Non-branching pivots should have been handled already!");
+
+        symbolicSample = getSymbolicSampleNumerator(splitRow);
+        if (domainSimplex.isRedundantInequality(symbolicSample))
+          continue;
+
+        // It's neither redundant nor separate, so it takes both positive and
+        // negative values, and hence constitutes a row for which we need to
+        // split the domain and separately run each case.
+        assert(!domainSimplex.isSeparateInequality(symbolicSample) &&
+               "Non-branching pivots should have been handled already!");
+        break;
+      }
+
+      if (splitRow < nRow) {
+        unsigned domainSnapshot = domainSimplex.getSnapshot();
+        IntegerRelation::CountsSnapshot domainPolyCounts =
+            domainPoly.getCounts();
+
+        // First, we consider the part of the domain where the row is not
+        // violated. We don't have to do any pivots for the row in this case,
+        // but we record the additional constraint that defines this part of
+        // the domain.
+        domainSimplex.addInequality(symbolicSample);
+        domainPoly.addInequality(symbolicSample);
+
+        // Recurse.
+        //
+        // On return, the basis as a set is preserved but not the internal
+        // ordering within rows or columns. Thus, we take note of the index of
+        // the Unknown that caused the split, which may be in a different
+        // row when we come back from recursing. We will need this to recurse
+        // on the other part of the split domain, where the row is violated.
+        //
+        // Note that we have to capture the index above and not a reference to
+        // the Unknown itself, since the array it lives in might get
+        // reallocated.
+        int splitIndex = rowUnknown[splitRow];
+        unsigned snapshot = getSnapshot();
+        stack.push_back(
+            {splitIndex, snapshot, domainSnapshot, domainPolyCounts});
+        ++level;
+        continue;
+      }
+
+      // The tableau is rationally consistent for the current domain.
+      // Now we look for non-integral sample values and add cuts for them.
+      if (Optional<unsigned> row = maybeGetNonIntegralVarRow()) {
+        if (addSymbolicCut(*row).failed()) {
+          // No integral points; return.
+          --level;
+          continue;
+        }
+
+        // Rerun this level with the added cut constraint (tail recurse).
+        continue;
+      }
+
+      // Record output and return.
+      recordOutput(result);
+      --level;
+      continue;
+    }
+
+    if (level == stack.size()) {
+      // We have "returned" from "recursing".
+      const StackFrame &frame = stack.back();
+      domainPoly.truncate(frame.domainPolyCounts);
+      domainSimplex.rollback(frame.domainSnapshot);
+      rollback(frame.snapshot);
+      const Unknown &u = unknownFromIndex(frame.splitIndex);
+
+      // Drop the frame. We don't need it anymore.
+      stack.pop_back();
+
+      // Now we consider the part of the domain where the unknown `splitIndex`
+      // was negative.
+      assert(u.orientation == Orientation::Row &&
+             "The split row should have been returned to row orientation!");
+      SmallVector<int64_t, 8> splitIneq =
+          getComplementIneq(getSymbolicSampleNumerator(u.pos));
+      if (moveRowUnknownToColumn(u.pos).failed()) {
+        // The unknown can't be made non-negative; return.
+        --level;
+        continue;
+      }
+
+      // The pivot succeeded, so the unknown can be made non-negative; recurse
+      // on the part of the domain where its symbolic sample was negative.
+      domainSimplex.addInequality(splitIneq);
+      domainPoly.addInequality(splitIneq);
+
+      // We are now taking care of the second half of the domain and we don't
+      // need to do anything else here after returning, so it's a tail recurse.
+      continue;
+    }
+  }
+
+  return result;
+}
+
 bool LexSimplex::rowIsViolated(unsigned row) const {
   if (tableau(row, 2) < 0)
     return true;
@@ -243,19 +653,20 @@ Optional<unsigned> LexSimplex::maybeGetViolatedRow() const {
   return {};
 }
 
-// We simply look for violated rows and keep trying to move them to column
-// orientation, which always succeeds unless the constraints have no solution
-// in which case we just give up and return.
-void LexSimplex::restoreRationalConsistency() {
-  while (Optional<unsigned> maybeViolatedRow = maybeGetViolatedRow()) {
-    LogicalResult status = moveRowUnknownToColumn(*maybeViolatedRow);
-    if (failed(status))
-      return;
-  }
+/// We simply look for violated rows and keep trying to move them to column
+/// orientation, which always succeeds unless the constraints have no solution,
+/// in which case we give up and return failure. Failure is also returned
+/// immediately if the tableau is already marked empty.
+LogicalResult LexSimplex::restoreRationalConsistency() {
+  if (empty)
+    return failure();
+  while (Optional<unsigned> maybeViolatedRow = maybeGetViolatedRow())
+    if (moveRowUnknownToColumn(*maybeViolatedRow).failed())
+      return failure();
+  return success();
 }
 
 // Move the row unknown to column orientation while preserving lexicopositivity
-// of the basis transform.
+// of the basis transform. The sample value of the row must be negative.
 //
 // We only consider pivots where the pivot element is positive. Suppose no such
 // pivot exists, i.e., some violated row has no positive coefficient for any
@@ -318,7 +729,7 @@ void LexSimplex::restoreRationalConsistency() {
 // minimizes the change in sample value.
 LogicalResult LexSimplexBase::moveRowUnknownToColumn(unsigned row) {
   Optional<unsigned> maybeColumn;
-  for (unsigned col = 3; col < nCol; ++col) {
+  for (unsigned col = 3 + nSymbol; col < nCol; ++col) {
     if (tableau(row, col) <= 0)
       continue;
     maybeColumn =
@@ -336,6 +747,7 @@ LogicalResult LexSimplexBase::moveRowUnknownToColumn(unsigned row) {
 
 unsigned LexSimplexBase::getLexMinPivotColumn(unsigned row, unsigned colA,
                                               unsigned colB) const {
+  // First, let's consider the non-symbolic case.
   // A pivot causes the following change. (in the diagram the matrix elements
   // are shown as rationals and there is no common denominator used)
   //
@@ -359,7 +771,7 @@ unsigned LexSimplexBase::getLexMinPivotColumn(unsigned row, unsigned colA,
   // (-p/a)M + (-b/a), i.e. 0 to -(pM + b)/a. Thus the change in the sample
   // value is -s/a.
   //
-  // If the variable is the pivot row, it sampel value goes from s to 0, for a
+  // If the variable is the pivot row, its sample value goes from s to 0, for a
   // change of -s.
   //
   // If the variable is a non-pivot row, its sample value changes from
@@ -373,8 +785,12 @@ unsigned LexSimplexBase::getLexMinPivotColumn(unsigned row, unsigned colA,
   // comparisons involved and can be ignored, since -s is strictly positive.
   //
   // Thus we take away this common factor and just return 0, 1/a, 1, or c/a as
-  // appropriate. This allows us to run the entire algorithm without ever having
-  // to fix a value of M.
+  // appropriate. This allows us to run the entire algorithm treating M
+  // symbolically, as the pivot to be performed does not depend on the value
+  // of M, so long as the sample value s is negative. Note that this is not
+  // because of any special feature of M; by the same argument, we ignore the
+  // symbols too. The caller ensure that the sample value s is negative for
+  // all possible values of the symbols.
   auto getSampleChangeCoeffForVar = [this, row](unsigned col,
                                                 const Unknown &u) -> Fraction {
     int64_t a = tableau(row, col);
@@ -489,6 +905,7 @@ void SimplexBase::pivot(Pivot pair) { pivot(pair.row, pair.column); }
 /// element.
 void SimplexBase::pivot(unsigned pivotRow, unsigned pivotCol) {
   assert(pivotCol >= getNumFixedCols() && "Refusing to pivot invalid column");
+  assert(!unknownFromColumn(pivotCol).isSymbol);
 
   swapRowWithCol(pivotRow, pivotCol);
   std::swap(tableau(pivotRow, 0), tableau(pivotRow, pivotCol));
@@ -778,6 +1195,9 @@ void SimplexBase::undo(UndoLogEntry entry) {
     assert(var.back().orientation == Orientation::Column &&
            "Variable to be removed must be in column orientation!");
 
+    if (var.back().isSymbol)
+      nSymbol--;
+
     // Move this variable to the last column and remove the column from the
     // tableau.
     swapColumns(var.back().pos, nCol - 1);

diff  --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
index 4149d85d8759f..2cb6ada89397a 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
@@ -8,6 +8,7 @@
 
 #include "./Utils.h"
 #include "mlir/Analysis/Presburger/IntegerRelation.h"
+#include "mlir/Analysis/Presburger/PWMAFunction.h"
 #include "mlir/Analysis/Presburger/Simplex.h"
 
 #include <gmock/gmock.h>
@@ -1134,6 +1135,229 @@ TEST(IntegerPolyhedronTest, findIntegerLexMin) {
                                   ">= 0, -11*z + 5*y - 3*x + 7 >= 0)"));
 }
 
+/// Parses the polyhedron `polyStr`, computes its symbolic integer lexmin, and
+/// checks the result against the expected values.
+///
+/// `expectedLexminRepr` is parsed into a PWMAFunction whose inputs are the
+/// polyhedron's symbols and whose outputs are its dimensions.
+/// `expectedUnboundedDomainRepr` is parsed into a PresburgerSet over the
+/// symbols and compared with the computed unbounded domain. On a mismatch,
+/// both the computed and the expected value are dumped to aid debugging.
+void expectSymbolicIntegerLexMin(
+    StringRef polyStr,
+    ArrayRef<std::pair<StringRef, SmallVector<SmallVector<int64_t, 8>, 8>>>
+        expectedLexminRepr,
+    ArrayRef<StringRef> expectedUnboundedDomainRepr) {
+  IntegerPolyhedron poly = parsePoly(polyStr);
+
+  // This helper requires the polyhedron to have both dims and symbols.
+  ASSERT_NE(poly.getNumDimIds(), 0u);
+  ASSERT_NE(poly.getNumSymbolIds(), 0u);
+
+  PWMAFunction expectedLexmin =
+      parsePWMAF(/*numInputs=*/poly.getNumSymbolIds(),
+                 /*numOutputs=*/poly.getNumDimIds(), expectedLexminRepr);
+
+  PresburgerSet expectedUnboundedDomain = parsePresburgerSetFromPolyStrings(
+      poly.getNumSymbolIds(), expectedUnboundedDomainRepr);
+
+  SymbolicLexMin result = poly.findSymbolicIntegerLexMin();
+
+  // Evaluate each equality check once and reuse it for reporting.
+  bool lexminMatches = result.lexmin.isEqual(expectedLexmin);
+  EXPECT_TRUE(lexminMatches);
+  if (!lexminMatches) {
+    llvm::errs() << "got:\n";
+    result.lexmin.dump();
+    llvm::errs() << "expected:\n";
+    expectedLexmin.dump();
+  }
+
+  bool unboundedDomainMatches =
+      result.unboundedDomain.isEqual(expectedUnboundedDomain);
+  EXPECT_TRUE(unboundedDomainMatches);
+  if (!unboundedDomainMatches) {
+    llvm::errs() << "got unbounded domain:\n";
+    result.unboundedDomain.dump();
+    llvm::errs() << "expected unbounded domain:\n";
+    expectedUnboundedDomain.dump();
+  }
+}
+
+/// Convenience overload that passes no polyhedra for the expected unbounded
+/// domain. It is used by cases whose lexmin is expected to be bounded wherever
+/// it exists.
+void expectSymbolicIntegerLexMin(
+    StringRef polyStr,
+    ArrayRef<std::pair<StringRef, SmallVector<SmallVector<int64_t, 8>, 8>>>
+        expectedLexminRepr) {
+  expectSymbolicIntegerLexMin(polyStr, expectedLexminRepr,
+                              /*expectedUnboundedDomainRepr=*/{});
+}
+
+// Each case gives a polyhedron over (dims)[symbols] and the expected lexmin
+// as a list of (domain over the symbols, output matrix) pieces. Each output
+// row holds coefficients for the symbols, for any local (floordiv) ids of the
+// piece's domain, and for the constant term. A case may also give the
+// expected set of symbol values for which the lexmin is unbounded.
+TEST(IntegerPolyhedronTest, findSymbolicIntegerLexMin) {
+  expectSymbolicIntegerLexMin("(x)[a] : (x - a >= 0)",
+                              {
+                                  {"(a) : ()", {{1, 0}}}, // a
+                              });
+
+  expectSymbolicIntegerLexMin(
+      "(x)[a, b] : (x - a >= 0, x - b >= 0)",
+      {
+          {"(a, b) : (a - b >= 0)", {{1, 0, 0}}},     // a
+          {"(a, b) : (b - a - 1 >= 0)", {{0, 1, 0}}}, // b
+      });
+
+  expectSymbolicIntegerLexMin(
+      "(x)[a, b, c] : (x -a >= 0, x - b >= 0, x - c >= 0)",
+      {
+          {"(a, b, c) : (a - b >= 0, a - c >= 0)", {{1, 0, 0, 0}}},         // a
+          {"(a, b, c) : (b - a - 1 >= 0, b - c >= 0)", {{0, 1, 0, 0}}},     // b
+          {"(a, b, c) : (c - a - 1 >= 0, c - b - 1 >= 0)", {{0, 0, 1, 0}}}, // c
+      });
+
+  expectSymbolicIntegerLexMin("(x, y)[a] : (x - a >= 0, x + y >= 0)",
+                              {
+                                  {"(a) : ()", {{1, 0}, {-1, 0}}}, // (a, -a)
+                              });
+
+  expectSymbolicIntegerLexMin(
+      "(x, y)[a] : (x - a >= 0, x + y >= 0, y >= 0)",
+      {
+          {"(a) : (a >= 0)", {{1, 0}, {0, 0}}},       // (a, 0)
+          {"(a) : (-a - 1 >= 0)", {{1, 0}, {-1, 0}}}, // (a, -a)
+      });
+
+  expectSymbolicIntegerLexMin(
+      "(x, y)[a, b, c] : (x - a >= 0, y - b >= 0, c - x - y >= 0)",
+      {
+          {"(a, b, c) : (c - a - b >= 0)",
+           {{1, 0, 0, 0}, {0, 1, 0, 0}}}, // (a, b)
+      });
+
+  expectSymbolicIntegerLexMin(
+      "(x, y, z)[a, b, c] : (c - z >= 0, b - y >= 0, x + y + z - a == 0)",
+      {
+          {"(a, b, c) : ()",
+           {{1, -1, -1, 0}, {0, 1, 0, 0}, {0, 0, 1, 0}}}, // (a - b - c, b, c)
+      });
+
+  expectSymbolicIntegerLexMin(
+      "(x)[a, b] : (a >= 0, b >= 0, x >= 0, a + b + x - 1 >= 0)",
+      {
+          {"(a, b) : (a >= 0, b >= 0, a + b - 1 >= 0)", {{0, 0, 0}}}, // 0
+          {"(a, b) : (a == 0, b == 0)", {{0, 0, 1}}},                 // 1
+      });
+
+  expectSymbolicIntegerLexMin(
+      "(x)[a, b] : (1 - a >= 0, a >= 0, 1 - b >= 0, b >= 0, 1 - x >= 0, x >= "
+      "0, a + b + x - 1 >= 0)",
+      {
+          {"(a, b) : (1 - a >= 0, a >= 0, 1 - b >= 0, b >= 0, a + b - 1 >= 0)",
+           {{0, 0, 0}}},                              // 0
+          {"(a, b) : (a == 0, b == 0)", {{0, 0, 1}}}, // 1
+      });
+
+  expectSymbolicIntegerLexMin(
+      "(x, y, z)[a, b] : (x - a == 0, y - b == 0, x >= 0, y >= 0, z >= 0, x + "
+      "y + z - 1 >= 0)",
+      {
+          {"(a, b) : (a >= 0, b >= 0, 1 - a - b >= 0)",
+           {{1, 0, 0}, {0, 1, 0}, {-1, -1, 1}}}, // (a, b, 1 - a - b)
+          {"(a, b) : (a >= 0, b >= 0, a + b - 2 >= 0)",
+           {{1, 0, 0}, {0, 1, 0}, {0, 0, 0}}}, // (a, b, 0)
+      });
+
+  expectSymbolicIntegerLexMin("(x)[a, b] : (x - a == 0, x - b >= 0)",
+                              {
+                                  {"(a, b) : (a - b >= 0)", {{1, 0, 0}}}, // a
+                              });
+
+  expectSymbolicIntegerLexMin(
+      "(q)[a] : (a - 1 - 3*q == 0, q >= 0)",
+      {
+          {"(a) : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)",
+           {{0, 1, 0}}}, // a floordiv 3
+      });
+
+  expectSymbolicIntegerLexMin(
+      "(r, q)[a] : (a - r - 3*q == 0, q >= 0, 1 - r >= 0, r >= 0)",
+      {
+          {"(a) : (a - 0 - 3*(a floordiv 3) == 0, a >= 0)",
+           {{0, 0, 0}, {0, 1, 0}}}, // (0, a floordiv 3)
+          {"(a) : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)",
+           {{0, 0, 1}, {0, 1, 0}}}, // (1, a floordiv 3)
+      });
+
+  expectSymbolicIntegerLexMin(
+      "(r, q)[a] : (a - r - 3*q == 0, q >= 0, 2 - r >= 0, r - 1 >= 0)",
+      {
+          {"(a) : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)",
+           {{0, 0, 1}, {0, 1, 0}}}, // (1, a floordiv 3)
+          {"(a) : (a - 2 - 3*(a floordiv 3) == 0, a >= 0)",
+           {{0, 0, 2}, {0, 1, 0}}}, // (2, a floordiv 3)
+      });
+
+  expectSymbolicIntegerLexMin(
+      "(r, q)[a] : (a - r - 3*q == 0, q >= 0, r >= 0)",
+      {
+          {"(a) : (a - 3*(a floordiv 3) == 0, a >= 0)",
+           {{0, 0, 0}, {0, 1, 0}}}, // (0, a floordiv 3)
+          {"(a) : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)",
+           {{0, 0, 1}, {0, 1, 0}}}, // (1, a floordiv 3)
+          {"(a) : (a - 2 - 3*(a floordiv 3) == 0, a >= 0)",
+           {{0, 0, 2}, {0, 1, 0}}}, // (2, a floordiv 3)
+      });
+
+  expectSymbolicIntegerLexMin(
+      "(x, y, z, w)[g] : ("
+      // x, y, z, w are boolean variables.
+      "1 - x >= 0, x >= 0, 1 - y >= 0, y >= 0,"
+      "1 - z >= 0, z >= 0, 1 - w >= 0, w >= 0,"
+      // We have some constraints on them:
+      "x + y + z - 1 >= 0,"             // x or y or z
+      "x + y + w - 1 >= 0,"             // x or y or w
+      "1 - x + 1 - y + 1 - w - 1 >= 0," // ~x or ~y or ~w
+      // What's the lexmin solution using exactly g true vars?
+      "g - x - y - z - w == 0)",
+      {
+          {"(g) : (g - 1 == 0)",
+           {{0, 0}, {0, 1}, {0, 0}, {0, 0}}}, // (0, 1, 0, 0)
+          {"(g) : (g - 2 == 0)",
+           {{0, 0}, {0, 0}, {0, 1}, {0, 1}}}, // (0, 0, 1, 1)
+          {"(g) : (g - 3 == 0)",
+           {{0, 0}, {0, 1}, {0, 1}, {0, 1}}}, // (0, 1, 1, 1)
+      });
+
+  // Bezout's lemma: if a, b are constants,
+  // the set of values that ax + by can take is all multiples of gcd(a, b).
+  expectSymbolicIntegerLexMin(
+      // If (x, y) is a solution for a given [a, r], then so is (x - 5, y + 2).
+      // So the lexmin is unbounded if it exists.
+      "(x, y)[a, r] : (a >= 0, r - a + 14*x + 35*y == 0)", {},
+      // According to Bezout's lemma, 14x + 35y can take on all multiples
+      // of 7 and no other values. So the solution exists iff r - a is a
+      // multiple of 7.
+      {"(a, r) : (a >= 0, r - a - 7*((r - a) floordiv 7) == 0)"});
+
+  // The lexmins are unbounded.
+  expectSymbolicIntegerLexMin("(x, y)[a] : (9*x - 4*y - 2*a >= 0)", {},
+                              {"(a) : ()"});
+
+  // Test cases adapted from isl.
+  expectSymbolicIntegerLexMin(
+      // a = 2b - 2(c - b), c - b >= 0.
+      // So b is minimized when c = b.
+      "(b, c)[a] : (a - 4*b + 2*c == 0, c - b >= 0)",
+      {
+          {"(a) : (a - 2*(a floordiv 2) == 0)",
+           {{0, 1, 0}, {0, 1, 0}}}, // (a floordiv 2, a floordiv 2)
+      });
+
+  expectSymbolicIntegerLexMin(
+      // 0 <= b <= 255, 1 <= a - 512b <= 509,
+      // b + 8 >= 1 + 16*((b + 8) floordiv 16), i.e., b % 16 != 8
+      "(b)[a] : (255 - b >= 0, b >= 0, a - 512*b - 1 >= 0, 512*b -a + 509 >= "
+      "0, b + 7 - 16*((8 + b) floordiv 16) >= 0)",
+      {
+          {"(a) : (255 - (a floordiv 512) >= 0, a >= 0, a - 512*(a floordiv "
+           "512) - 1 >= 0, 512*(a floordiv 512) - a + 509 >= 0, (a floordiv "
+           "512) + 7 - 16*((8 + (a floordiv 512)) floordiv 16) >= 0)",
+           {{0, 1, 0, 0}}}, // a floordiv 512
+      });
+
+  expectSymbolicIntegerLexMin(
+      "(a, b)[K, N, x, y] : (N - K - 2 >= 0, K + 4 - N >= 0, x - 4 >= 0, x + 6 "
+      "- 2*N >= 0, K+N - x - 1 >= 0, a - N + 1 >= 0, K+N-1-a >= 0,a + 6 - b - "
+      "N >= 0, 2*N - 4 - a >= 0,"
+      "2*N - 3*K + a - b >= 0, 4*N - K + 1 - 3*b >= 0, b - N >= 0, a - x - 1 "
+      ">= 0)",
+      {{
+          "(K, N, x, y) : (x + 6 - 2*N >= 0, 2*N - 5 - x >= 0, x + 1 -3*K + N "
+          ">= 0, N + K - 2 - x >= 0, x - 4 >= 0)",
+          {{0, 0, 1, 0, 1}, {0, 1, 0, 0, 0}} // (1 + x, N)
+      }});
+}
+
 static void
 expectComputedVolumeIsValidOverapprox(const IntegerPolyhedron &poly,
                                       Optional<uint64_t> trueVolume,


        


More information about the Mlir-commits mailing list